[CI/Build] Split up models tests (#10069)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-11-10 03:39:14 +08:00 committed by GitHub
parent b09895a618
commit 51c2e1fcef
21 changed files with 115 additions and 129 deletions

View File

@@ -305,7 +305,7 @@ steps:
 ##### models test #####
-- label: Basic Models Test # 3min
+- label: Basic Models Test # 10min
   source_file_dependencies:
   - vllm/
   - tests/models
@@ -314,23 +314,24 @@ steps:
   - pytest -v -s models/test_oot_registration.py # it needs a clean process
   - pytest -v -s models/*.py --ignore=models/test_oot_registration.py
-- label: Decoder-only Language Models Test (Standard) # 35min
+- label: Decoder-only Language Models Test (Standard) # 18min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/language
   commands:
-  - pytest -v -s models/decoder_only/language/test_models.py
+  - pytest -v -s models/decoder_only/language -m core_model
+  - pytest -v -s models/decoder_only/language -m quant_model
-- label: Decoder-only Language Models Test (Extended) # 1h20min
+- label: Decoder-only Language Models Test (Extended) # 46min
   nightly: true
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/language
   commands:
-  - pytest -v -s models/decoder_only/language --ignore=models/decoder_only/language/test_models.py
+  - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
-- label: Decoder-only Multi-Modal Models Test (Standard) # 26min
+- label: Decoder-only Multi-Modal Models Test (Standard) # 22min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
@@ -339,21 +340,24 @@ steps:
   commands:
   - pytest -v -s models/decoder_only/audio_language -m core_model
   - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model
+  # No tests under this group for now
+  # - pytest -v -s models/decoder_only/audio_language -m quant_model
+  - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m quant_model
-- label: Decoder-only Multi-Modal Models Test (Extended)
+- label: Decoder-only Multi-Modal Models Test (Extended) # 1h10m
   nightly: true
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/audio_language
   - tests/models/decoder_only/vision_language
   commands:
-  - pytest -v -s models/decoder_only/audio_language -m 'not core_model'
+  - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
   # HACK - run phi3v tests separately to sidestep this transformers bug
   # https://github.com/huggingface/transformers/issues/34307
   - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-  - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model'
+  - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
-- label: Other Models Test # 6min
+- label: Other Models Test # 20min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/

View File

@@ -95,6 +95,7 @@ markers = [
     "skip_global_cleanup",
     "core_model: enable this model test in each PR instead of only nightly",
     "cpu_model: enable this model test in CPU tests",
+    "quant_model: run this model test under Quantized category",
     "distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
     "skip_v1: do not run this test with v1",
 ]
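For context, a rough sketch of how a test opts into the new category (the test name below is hypothetical, not part of this commit): marking a test with the registered quant_model marker lets the Standard CI job select it with `pytest -m quant_model`, while the Extended job excludes it with `-m 'not core_model and not quant_model'`, matching the pipeline commands shown above.

import pytest


@pytest.mark.quant_model  # selected by `pytest -m quant_model` in the Standard job
def test_some_quantized_model():  # hypothetical test name, for illustration only
    # Excluded from the Extended job via -m 'not core_model and not quant_model'.
    assert True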

View File

@@ -38,6 +38,7 @@ ground_truth_generations = [
 ]
+@pytest.mark.quant_model
 @pytest.mark.skipif(not is_quant_method_supported("aqlm"),
                     reason="AQLM is not supported on this GPU type.")
 @pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])

View File

@@ -15,6 +15,7 @@ from ...utils import check_logprobs_close
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
+@pytest.mark.quant_model
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="fp8 is not supported on this GPU type.")
 @pytest.mark.parametrize(

View File

@@ -17,26 +17,21 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
 MAX_MODEL_LEN = 1024
-# FIXME: Move this to confest
-MODELS = [
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
-                     filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")),
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
-                     filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")),
-    ("Qwen/Qwen2-1.5B-Instruct",
-     hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
-                     filename="qwen2-1_5b-instruct-q4_k_m.gguf")),
-    ("Qwen/Qwen2-1.5B-Instruct",
-     hf_hub_download("legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
-                     filename="Qwen2-1.5B-Instruct.IQ4_XS.gguf")),
-]
 @pytest.mark.skipif(not is_quant_method_supported("gguf"),
                     reason="gguf is not supported on this GPU type.")
-@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize(("original_model", "gguf_id", "gguf_path"), [
+    ("meta-llama/Llama-3.2-1B-Instruct",
+     "bartowski/Llama-3.2-1B-Instruct-GGUF",
+     "Llama-3.2-1B-Instruct-Q4_K_M.gguf"),
+    ("meta-llama/Llama-3.2-1B-Instruct",
+     "bartowski/Llama-3.2-1B-Instruct-GGUF",
+     "Llama-3.2-1B-Instruct-IQ4_XS.gguf"),
+    ("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GGUF",
+     "qwen2-1_5b-instruct-q4_k_m.gguf"),
+    ("Qwen/Qwen2-1.5B-Instruct", "legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
+     "Qwen2-1.5B-Instruct.IQ4_XS.gguf"),
+])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
@@ -45,7 +40,9 @@ def test_models(
     num_gpus_available,
     vllm_runner,
     example_prompts,
-    model,
+    original_model,
+    gguf_id,
+    gguf_path,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
@@ -54,7 +51,7 @@ def test_models(
     if num_gpus_available < tp_size:
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
-    original_model, gguf_model = model
+    gguf_model = hf_hub_download(gguf_id, filename=gguf_path)
     tokenizer = AutoTokenizer.from_pretrained(original_model)
     messages = [[{
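As a side note, a minimal sketch of the pattern adopted above, assuming only that huggingface_hub.hf_hub_download is available: the GGUF file is now resolved inside the test body rather than in a module-level MODELS list, so importing (and hence collecting) the module no longer triggers downloads; they happen only when the test actually runs. The helper name below is hypothetical.

from huggingface_hub import hf_hub_download


def resolve_gguf(gguf_id: str, gguf_path: str) -> str:  # hypothetical helper
    # Returns the local cache path of the requested file, downloading it on
    # first use; meant to be called from inside the test, not at import time.
    return hf_hub_download(gguf_id, filename=gguf_path)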

View File

@@ -33,6 +33,7 @@ MODELS = [
 ]
+@pytest.mark.quant_model
 @pytest.mark.flaky(reruns=3)
 @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                     reason="gptq_marlin is not supported on this GPU type.")

View File

@@ -38,6 +38,7 @@ model_pairs = [
 ]
+@pytest.mark.quant_model
 @pytest.mark.flaky(reruns=2)
 @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"),
                     reason="Marlin24 is not supported on this GPU type.")

View File

@@ -7,7 +7,9 @@ import pytest
 from ...utils import check_logprobs_close
 MODELS = [
+    # TODO(sang): Sliding window should be tested separately.
     "ibm/PowerLM-3b",
+    "ibm/PowerMoE-3b",
 ]
@@ -24,7 +26,6 @@ def test_models(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    # TODO(sang): Sliding window should be tested separately.
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)

View File

@ -1,39 +0,0 @@
"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
Run `pytest tests/models/test_granite.py`.
"""
import pytest
from ...utils import check_logprobs_close
MODELS = [
"ibm/PowerMoE-3b",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)

View File

@@ -39,6 +39,7 @@ EXPECTED_STRS_MAP = {
 @pytest.mark.skip(
     reason=
     "Prevent unstable test based on golden strings from breaking the build.")
+@pytest.mark.quant_model
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="fp8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_name", MODELS)

View File

@ -1,8 +1,5 @@
"""Compare the outputs of HF and vLLM when using greedy sampling. """Compare the outputs of HF and vLLM when using greedy sampling.
This test only tests small models. Big models such as 7B should be tested from
test_big_models.py because it could use a larger instance to run tests.
Run `pytest tests/models/test_models.py`. Run `pytest tests/models/test_models.py`.
""" """
import pytest import pytest
@ -35,6 +32,7 @@ if not current_platform.is_cpu():
target_dtype = "half" target_dtype = "half"
@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("max_tokens", [32])

View File

@@ -56,11 +56,13 @@ def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
     ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
     seq_len = 5000  # bigger than the max feature size for any image
-    seq_data, mm_data = dummy_data_for_llava_next(
+    dummy_data = dummy_data_for_llava_next(
         ctx,
         seq_len=seq_len,
         mm_counts={"image": 1},
     )
+    seq_data = dummy_data.seq_data
+    mm_data = dummy_data.multi_modal_data
     # The dummy data dims should match the gridpoint with the biggest feat size
     assert mm_data["image"].height == expected_size[0]
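For readers following along, a standalone sketch of the return shape these tests now assume; the attribute names come straight from the diff, but the concrete class in vLLM may differ from this stand-in.

from typing import Any, NamedTuple, Optional


class DummyData(NamedTuple):  # stand-in for whatever the dummy-data factories return
    seq_data: Any
    multi_modal_data: Optional[dict] = None


# The factories used to return a (seq_data, mm_data) tuple; the tests now read
# named attributes instead, which tolerates extra fields being added later.
dummy = DummyData(seq_data=[0, 1, 2], multi_modal_data={"image": object()})
seq_data, mm_data = dummy.seq_data, dummy.multi_modal_data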

View File

@@ -131,12 +131,13 @@ def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
         mm_processor_kwargs=None,
     )
-    sequence_data, _, = dummy_data_for_phi3v(
+    dummy_data = dummy_data_for_phi3v(
         ctx=ctx,
         seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
         mm_counts={"image": num_imgs},
         num_crops=num_crops,
     )
+    sequence_data = dummy_data.seq_data
     # Ensure we have the right number of placeholders per num_crops size
     img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
     assert img_tok_count == toks_per_img * num_imgs

View File

@@ -86,10 +86,17 @@ def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
     # NOTE: video value is required, but isn't actually used
     # when making the dummy data except for error handling currently
-    seq_data, mm_data = dummy_data_for_qwen2_vl(qwen2_vl_context, seq_len, {
+    dummy_data = dummy_data_for_qwen2_vl(
+        ctx=qwen2_vl_context,
+        seq_len=seq_len,
+        mm_counts={
         "image": 1,
         "video": 0
-    }, **mm_processor_kwargs)
+        },
+        **mm_processor_kwargs,
+    )
+    seq_data = dummy_data.seq_data
+    mm_data = dummy_data.multi_modal_data
     # Ensure we have the right number of placeholders for min/max pixel values
     assert seq_data.get_token_ids().count(image_token_id) == token_count

View File

@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Type
 import pytest
 import torch
@@ -19,7 +19,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
 def run_awq_test(
     vllm_runner: Type[VllmRunner],
     image_assets: _ImageAssets,
-    models: Tuple[str, str],
+    source_model: str,
+    quant_model: str,
     *,
     size_factors: List[float],
     dtype: str,
@@ -28,8 +29,6 @@ def run_awq_test(
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
 ):
-    source_model, quant_model = models
     images = [asset.pil_image for asset in image_assets]
     inputs_per_image = [(
@@ -84,8 +83,11 @@ def run_awq_test(
 )
+@pytest.mark.quant_model
 @pytest.mark.parametrize(
-    "models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
+    ("source_model", "quant_model"),
+    [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")],
+)
 @pytest.mark.parametrize(
     "size_factors",
     [
@@ -103,12 +105,13 @@ def run_awq_test(
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 @torch.inference_mode()
-def test_awq_models(vllm_runner, image_assets, models, size_factors,
-                    dtype: str, max_tokens: int, num_logprobs: int) -> None:
+def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
+                    size_factors, dtype, max_tokens, num_logprobs) -> None:
     run_awq_test(
         vllm_runner,
         image_assets,
-        models,
+        source_model,
+        quant_model,
         size_factors=size_factors,
         dtype=dtype,
         max_tokens=max_tokens,
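A quick, self-contained illustration of the parametrization style adopted above (the model ids below are placeholders, not the ones used in the real test): passing the source and quantized checkpoints as separate argnames rather than a single tuple keeps the test signature explicit and avoids the manual unpacking that was removed in this diff.

import pytest


@pytest.mark.parametrize(
    ("source_model", "quant_model"),
    [("org/some-model", "org/some-model-awq")],  # placeholder ids
)
def test_pair(source_model: str, quant_model: str) -> None:
    # Each parameter arrives as its own argument; no tuple unpacking needed.
    assert source_model != quant_model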

View File

@@ -11,21 +11,17 @@ from ....conftest import _ImageAssets
 # we use snapshot_download to prevent conflicts between
 # dynamic_module and trust_remote_code for hf_runner
 DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
-models = [
-    snapshot_download("OpenGVLab/InternViT-300M-448px",
-                      allow_patterns=DOWNLOAD_PATTERN),
-    snapshot_download("OpenGVLab/InternViT-6B-448px-V1-5",
-                      allow_patterns=DOWNLOAD_PATTERN),
-]
 def run_intern_vit_test(
     image_assets: _ImageAssets,
-    model: str,
+    model_id: str,
     *,
     dtype: str,
     distributed_executor_backend: Optional[str] = None,
 ):
+    model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
     img_processor = CLIPImageProcessor.from_pretrained(model)
     images = [asset.pil_image for asset in image_assets]
     pixel_values = [
@@ -67,12 +63,15 @@ def run_intern_vit_test(
     assert cos_similar(vllm_output, hf_output).mean() > 0.99
-@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("model_id", [
+    "OpenGVLab/InternViT-300M-448px",
+    "OpenGVLab/InternViT-6B-448px-V1-5",
+])
 @pytest.mark.parametrize("dtype", [torch.half])
 @torch.inference_mode()
-def test_models(dist_init, image_assets, model, dtype: str) -> None:
+def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
     run_intern_vit_test(
         image_assets,
-        model,
+        model_id,
         dtype=dtype,
     )

View File

@@ -130,8 +130,8 @@ VLM_TEST_SETTINGS = {
         max_num_seqs=2,
         auto_cls=AutoModelForVision2Seq,
         vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
-        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
+        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
     #### Extended model tests
     "blip2": VLMTestInfo(
@@ -159,9 +159,9 @@ VLM_TEST_SETTINGS = {
         dtype="bfloat16",
         marks=[
             pytest.mark.skipif(
-                transformers.__version__.startswith("4.46"),
+                transformers.__version__ < "4.46.2",
                 reason="Model broken in HF, see huggingface/transformers#34379"
-            )
+            ),
         ]
     ),
     "fuyu": VLMTestInfo(
@@ -185,8 +185,8 @@ VLM_TEST_SETTINGS = {
         max_num_seqs=2,
         dtype="bfloat16",
         get_stop_token_ids=lambda tok: [151329, 151336, 151338],
-        marks=[large_gpu_mark(min_gb=48)],
         patch_hf_runner=model_utils.glm_patch_hf_runner,
+        marks=[large_gpu_mark(min_gb=48)],
     ),
     "h2ovl": VLMTestInfo(
         models = [
@@ -205,6 +205,22 @@ VLM_TEST_SETTINGS = {
         use_tokenizer_eos=True,
         patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
     ),
+    "idefics3": VLMTestInfo(
+        models=["HuggingFaceM4/Idefics3-8B-Llama3"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>",
+        max_model_len=8192,
+        max_num_seqs=2,
+        auto_cls=AutoModelForVision2Seq,
+        marks=[
+            pytest.mark.skipif(
+                transformers.__version__ < "4.46.0",
+                reason="Model introduced in HF >= 4.46.0"
+            ),
+            large_gpu_mark(min_gb=48),
+        ],
+    ),
     "intern_vl": VLMTestInfo(
         models=[
             "OpenGVLab/InternVL2-1B",
@@ -263,7 +279,6 @@ VLM_TEST_SETTINGS = {
             runner_mm_key="videos",
         )],
     ),
-    # FIXME
     "llava_next_video": VLMTestInfo(
         models=["llava-hf/LLaVA-NeXT-Video-7B-hf"],
         test_type=VLMTestType.VIDEO,
@@ -275,7 +290,7 @@ VLM_TEST_SETTINGS = {
         image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
         marks=[
             pytest.mark.skipif(
-                transformers.__version__.startswith("4.46"),
+                transformers.__version__ < "4.46.2",
                 reason="Model broken with changes in transformers 4.46"
             )
         ],
@@ -316,6 +331,7 @@ VLM_TEST_SETTINGS = {
         max_model_len=8192,
         max_num_seqs=2,
         auto_cls=AutoModelForVision2Seq,
+        marks=[large_gpu_mark(min_gb=48)],
     ),
     "qwen": VLMTestInfo(
         models=["Qwen/Qwen-VL"],
@@ -327,22 +343,6 @@ VLM_TEST_SETTINGS = {
         vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
         prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
     ),
-    "idefics3": VLMTestInfo(
-        models=["HuggingFaceM4/Idefics3-8B-Llama3"],
-        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
-        prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:",  # noqa: E501
-        img_idx_to_prompt=lambda idx: "<image>",
-        max_model_len=8192,
-        max_num_seqs=2,
-        auto_cls=AutoModelForVision2Seq,
-        marks=[
-            pytest.mark.skipif(
-                transformers.__version__ < "4.46.0",
-                reason="Model introduced in HF >= 4.46.0"
-            ),
-            large_gpu_mark(min_gb=48),
-        ],
-    ),
     ### Tensor parallel / multi-gpu broadcast tests
     "broadcast-chameleon": VLMTestInfo(
         models=["facebook/chameleon-7b"],
@@ -362,7 +362,7 @@ VLM_TEST_SETTINGS = {
                 reason="Need at least 2 GPUs to run the test.",
             ),
             pytest.mark.skipif(
-                transformers.__version__.startswith("4.46"),
+                transformers.__version__ < "4.46.2",
                 reason="Model broken in HF, see huggingface/transformers#34379"
             )
         ],

View File

@@ -1,7 +1,8 @@
+import copy
 import enum
 import json
 import warnings
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
                     Mapping, Optional, Set, Tuple, Type, Union)
@@ -2078,6 +2079,12 @@ class VllmConfig:
             return quant_config
         return None
+    def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig":
+        model_config = copy.deepcopy(self.model_config)
+        model_config.hf_config = hf_config
+        return replace(self, model_config=model_config)
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
         """

View File

@@ -229,7 +229,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
@@ -246,9 +245,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             quant_config=quant_config,
             gather_output=True,
         )
-        self.language_model = PersimmonForCausalLM(config.text_config,
-                                                   cache_config=cache_config,
-                                                   quant_config=quant_config)
+        self.language_model = PersimmonForCausalLM(
+            vllm_config.with_hf_config(config.text_config))
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)

View File

@@ -164,10 +164,12 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
         vllm_config: VllmConfig,
         prefix: str = "",
     ) -> None:
+        super().__init__(vllm_config, prefix=prefix)
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        super().__init__(config, cache_config, quant_config)
         self.model = InternLM2VEModel(config,
                                       cache_config,
                                       quant_config,

View File

@@ -241,11 +241,11 @@ def init_vllm_registered_model(
     based on the arguments passed to the outer vLLM model.
     """
     model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
-    import copy
-    copied_config = copy.deepcopy(vllm_config)
-    copied_config.model_config.hf_config = hf_config
-    return model_class(vllm_config=copied_config, prefix=prefix)
+    return model_class(
+        vllm_config=vllm_config.with_hf_config(hf_config),
+        prefix=prefix,
+    )
 @overload