[CI/Build] Split up models tests (#10069)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
b09895a618
commit
51c2e1fcef
@ -305,7 +305,7 @@ steps:
|
||||
|
||||
##### models test #####
|
||||
|
||||
- label: Basic Models Test # 3min
|
||||
- label: Basic Models Test # 10min
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models
|
||||
@ -314,23 +314,24 @@ steps:
|
||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s models/*.py --ignore=models/test_oot_registration.py
|
||||
|
||||
- label: Decoder-only Language Models Test (Standard) # 35min
|
||||
- label: Decoder-only Language Models Test (Standard) # 18min
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/decoder_only/language
|
||||
commands:
|
||||
- pytest -v -s models/decoder_only/language/test_models.py
|
||||
- pytest -v -s models/decoder_only/language -m core_model
|
||||
- pytest -v -s models/decoder_only/language -m quant_model
|
||||
|
||||
- label: Decoder-only Language Models Test (Extended) # 1h20min
|
||||
- label: Decoder-only Language Models Test (Extended) # 46min
|
||||
nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/decoder_only/language
|
||||
commands:
|
||||
- pytest -v -s models/decoder_only/language --ignore=models/decoder_only/language/test_models.py
|
||||
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
|
||||
|
||||
- label: Decoder-only Multi-Modal Models Test (Standard) # 26min
|
||||
- label: Decoder-only Multi-Modal Models Test (Standard) # 22min
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -339,21 +340,24 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s models/decoder_only/audio_language -m core_model
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model
|
||||
# No tests under this group for now
|
||||
# - pytest -v -s models/decoder_only/audio_language -m quant_model
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m quant_model
|
||||
|
||||
- label: Decoder-only Multi-Modal Models Test (Extended)
|
||||
- label: Decoder-only Multi-Modal Models Test (Extended) # 1h10m
|
||||
nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/decoder_only/audio_language
|
||||
- tests/models/decoder_only/vision_language
|
||||
commands:
|
||||
- pytest -v -s models/decoder_only/audio_language -m 'not core_model'
|
||||
- pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
|
||||
# HACK - run phi3v tests separately to sidestep this transformers bug
|
||||
# https://github.com/huggingface/transformers/issues/34307
|
||||
- pytest -v -s models/decoder_only/vision_language/test_phi3v.py
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model'
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
|
||||
|
||||
- label: Other Models Test # 6min
|
||||
- label: Other Models Test # 20min
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
|
@ -95,6 +95,7 @@ markers = [
|
||||
"skip_global_cleanup",
|
||||
"core_model: enable this model test in each PR instead of only nightly",
|
||||
"cpu_model: enable this model test in CPU tests",
|
||||
"quant_model: run this model test under Quantized category",
|
||||
"distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
|
||||
"skip_v1: do not run this test with v1",
|
||||
]
|
||||
|
@ -38,6 +38,7 @@ ground_truth_generations = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.quant_model
|
||||
@pytest.mark.skipif(not is_quant_method_supported("aqlm"),
|
||||
reason="AQLM is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
|
||||
|
@ -15,6 +15,7 @@ from ...utils import check_logprobs_close
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
|
||||
@pytest.mark.quant_model
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="fp8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -17,26 +17,21 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
|
||||
# FIXME: Move this to confest
|
||||
MODELS = [
|
||||
("meta-llama/Llama-3.2-1B-Instruct",
|
||||
hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
|
||||
filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")),
|
||||
("meta-llama/Llama-3.2-1B-Instruct",
|
||||
hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
|
||||
filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")),
|
||||
("Qwen/Qwen2-1.5B-Instruct",
|
||||
hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
|
||||
filename="qwen2-1_5b-instruct-q4_k_m.gguf")),
|
||||
("Qwen/Qwen2-1.5B-Instruct",
|
||||
hf_hub_download("legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
|
||||
filename="Qwen2-1.5B-Instruct.IQ4_XS.gguf")),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gguf"),
|
||||
reason="gguf is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize(("original_model", "gguf_id", "gguf_path"), [
|
||||
("meta-llama/Llama-3.2-1B-Instruct",
|
||||
"bartowski/Llama-3.2-1B-Instruct-GGUF",
|
||||
"Llama-3.2-1B-Instruct-Q4_K_M.gguf"),
|
||||
("meta-llama/Llama-3.2-1B-Instruct",
|
||||
"bartowski/Llama-3.2-1B-Instruct-GGUF",
|
||||
"Llama-3.2-1B-Instruct-IQ4_XS.gguf"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GGUF",
|
||||
"qwen2-1_5b-instruct-q4_k_m.gguf"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
|
||||
"Qwen2-1.5B-Instruct.IQ4_XS.gguf"),
|
||||
])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@ -45,7 +40,9 @@ def test_models(
|
||||
num_gpus_available,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model,
|
||||
original_model,
|
||||
gguf_id,
|
||||
gguf_path,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
@ -54,7 +51,7 @@ def test_models(
|
||||
if num_gpus_available < tp_size:
|
||||
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||
|
||||
original_model, gguf_model = model
|
||||
gguf_model = hf_hub_download(gguf_id, filename=gguf_path)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(original_model)
|
||||
messages = [[{
|
||||
|
@ -33,6 +33,7 @@ MODELS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.quant_model
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="gptq_marlin is not supported on this GPU type.")
|
||||
|
@ -38,6 +38,7 @@ model_pairs = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.quant_model
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"),
|
||||
reason="Marlin24 is not supported on this GPU type.")
|
||||
|
@ -7,7 +7,9 @@ import pytest
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
MODELS = [
|
||||
# TODO(sang): Sliding window should be tested separately.
|
||||
"ibm/PowerLM-3b",
|
||||
"ibm/PowerMoE-3b",
|
||||
]
|
||||
|
||||
|
||||
@ -24,7 +26,6 @@ def test_models(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
# TODO(sang): Sliding window should be tested separately.
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
@ -1,39 +0,0 @@
|
||||
"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_granite.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
MODELS = [
|
||||
"ibm/PowerMoE-3b",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
@ -39,6 +39,7 @@ EXPECTED_STRS_MAP = {
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"Prevent unstable test based on golden strings from breaking the build.")
|
||||
@pytest.mark.quant_model
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="fp8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model_name", MODELS)
|
||||
|
@ -1,8 +1,5 @@
|
||||
"""Compare the outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
This test only tests small models. Big models such as 7B should be tested from
|
||||
test_big_models.py because it could use a larger instance to run tests.
|
||||
|
||||
Run `pytest tests/models/test_models.py`.
|
||||
"""
|
||||
import pytest
|
||||
@ -35,6 +32,7 @@ if not current_platform.is_cpu():
|
||||
target_dtype = "half"
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
|
@ -56,11 +56,13 @@ def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
|
||||
ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
|
||||
seq_len = 5000 # bigger than the max feature size for any image
|
||||
|
||||
seq_data, mm_data = dummy_data_for_llava_next(
|
||||
dummy_data = dummy_data_for_llava_next(
|
||||
ctx,
|
||||
seq_len=seq_len,
|
||||
mm_counts={"image": 1},
|
||||
)
|
||||
seq_data = dummy_data.seq_data
|
||||
mm_data = dummy_data.multi_modal_data
|
||||
|
||||
# The dummy data dims should match the gridpoint with the biggest feat size
|
||||
assert mm_data["image"].height == expected_size[0]
|
||||
|
@ -131,12 +131,13 @@ def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
|
||||
mm_processor_kwargs=None,
|
||||
)
|
||||
|
||||
sequence_data, _, = dummy_data_for_phi3v(
|
||||
dummy_data = dummy_data_for_phi3v(
|
||||
ctx=ctx,
|
||||
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
|
||||
mm_counts={"image": num_imgs},
|
||||
num_crops=num_crops,
|
||||
)
|
||||
sequence_data = dummy_data.seq_data
|
||||
# Ensure we have the right number of placeholders per num_crops size
|
||||
img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
|
||||
assert img_tok_count == toks_per_img * num_imgs
|
||||
|
@ -86,10 +86,17 @@ def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
|
||||
|
||||
# NOTE: video value is required, but isn't actually used
|
||||
# when making the dummy data except for error handling currently
|
||||
seq_data, mm_data = dummy_data_for_qwen2_vl(qwen2_vl_context, seq_len, {
|
||||
"image": 1,
|
||||
"video": 0
|
||||
}, **mm_processor_kwargs)
|
||||
dummy_data = dummy_data_for_qwen2_vl(
|
||||
ctx=qwen2_vl_context,
|
||||
seq_len=seq_len,
|
||||
mm_counts={
|
||||
"image": 1,
|
||||
"video": 0
|
||||
},
|
||||
**mm_processor_kwargs,
|
||||
)
|
||||
seq_data = dummy_data.seq_data
|
||||
mm_data = dummy_data.multi_modal_data
|
||||
|
||||
# Ensure we have the right number of placeholders for min/max pixel values
|
||||
assert seq_data.get_token_ids().count(image_token_id) == token_count
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Tuple, Type
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -19,7 +19,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
def run_awq_test(
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
models: Tuple[str, str],
|
||||
source_model: str,
|
||||
quant_model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
@ -28,8 +29,6 @@ def run_awq_test(
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
source_model, quant_model = models
|
||||
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
@ -84,8 +83,11 @@ def run_awq_test(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.quant_model
|
||||
@pytest.mark.parametrize(
|
||||
"models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
|
||||
("source_model", "quant_model"),
|
||||
[("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
@ -103,12 +105,13 @@ def run_awq_test(
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@torch.inference_mode()
|
||||
def test_awq_models(vllm_runner, image_assets, models, size_factors,
|
||||
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||
def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
|
||||
size_factors, dtype, max_tokens, num_logprobs) -> None:
|
||||
run_awq_test(
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
models,
|
||||
source_model,
|
||||
quant_model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
@ -11,21 +11,17 @@ from ....conftest import _ImageAssets
|
||||
# we use snapshot_download to prevent conflicts between
|
||||
# dynamic_module and trust_remote_code for hf_runner
|
||||
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
|
||||
models = [
|
||||
snapshot_download("OpenGVLab/InternViT-300M-448px",
|
||||
allow_patterns=DOWNLOAD_PATTERN),
|
||||
snapshot_download("OpenGVLab/InternViT-6B-448px-V1-5",
|
||||
allow_patterns=DOWNLOAD_PATTERN),
|
||||
]
|
||||
|
||||
|
||||
def run_intern_vit_test(
|
||||
image_assets: _ImageAssets,
|
||||
model: str,
|
||||
model_id: str,
|
||||
*,
|
||||
dtype: str,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
|
||||
|
||||
img_processor = CLIPImageProcessor.from_pretrained(model)
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
pixel_values = [
|
||||
@ -67,12 +63,15 @@ def run_intern_vit_test(
|
||||
assert cos_similar(vllm_output, hf_output).mean() > 0.99
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("model_id", [
|
||||
"OpenGVLab/InternViT-300M-448px",
|
||||
"OpenGVLab/InternViT-6B-448px-V1-5",
|
||||
])
|
||||
@pytest.mark.parametrize("dtype", [torch.half])
|
||||
@torch.inference_mode()
|
||||
def test_models(dist_init, image_assets, model, dtype: str) -> None:
|
||||
def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
|
||||
run_intern_vit_test(
|
||||
image_assets,
|
||||
model,
|
||||
model_id,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -130,8 +130,8 @@ VLM_TEST_SETTINGS = {
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
|
||||
),
|
||||
#### Extended model tests
|
||||
"blip2": VLMTestInfo(
|
||||
@ -159,9 +159,9 @@ VLM_TEST_SETTINGS = {
|
||||
dtype="bfloat16",
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
transformers.__version__.startswith("4.46"),
|
||||
transformers.__version__ < "4.46.2",
|
||||
reason="Model broken in HF, see huggingface/transformers#34379"
|
||||
)
|
||||
),
|
||||
]
|
||||
),
|
||||
"fuyu": VLMTestInfo(
|
||||
@ -185,8 +185,8 @@ VLM_TEST_SETTINGS = {
|
||||
max_num_seqs=2,
|
||||
dtype="bfloat16",
|
||||
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
|
||||
marks=[large_gpu_mark(min_gb=48)],
|
||||
patch_hf_runner=model_utils.glm_patch_hf_runner,
|
||||
marks=[large_gpu_mark(min_gb=48)],
|
||||
),
|
||||
"h2ovl": VLMTestInfo(
|
||||
models = [
|
||||
@ -205,6 +205,22 @@ VLM_TEST_SETTINGS = {
|
||||
use_tokenizer_eos=True,
|
||||
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
|
||||
),
|
||||
"idefics3": VLMTestInfo(
|
||||
models=["HuggingFaceM4/Idefics3-8B-Llama3"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "<image>",
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
transformers.__version__ < "4.46.0",
|
||||
reason="Model introduced in HF >= 4.46.0"
|
||||
),
|
||||
large_gpu_mark(min_gb=48),
|
||||
],
|
||||
),
|
||||
"intern_vl": VLMTestInfo(
|
||||
models=[
|
||||
"OpenGVLab/InternVL2-1B",
|
||||
@ -263,7 +279,6 @@ VLM_TEST_SETTINGS = {
|
||||
runner_mm_key="videos",
|
||||
)],
|
||||
),
|
||||
# FIXME
|
||||
"llava_next_video": VLMTestInfo(
|
||||
models=["llava-hf/LLaVA-NeXT-Video-7B-hf"],
|
||||
test_type=VLMTestType.VIDEO,
|
||||
@ -275,7 +290,7 @@ VLM_TEST_SETTINGS = {
|
||||
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
transformers.__version__.startswith("4.46"),
|
||||
transformers.__version__ < "4.46.2",
|
||||
reason="Model broken with changes in transformers 4.46"
|
||||
)
|
||||
],
|
||||
@ -316,6 +331,7 @@ VLM_TEST_SETTINGS = {
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
marks=[large_gpu_mark(min_gb=48)],
|
||||
),
|
||||
"qwen": VLMTestInfo(
|
||||
models=["Qwen/Qwen-VL"],
|
||||
@ -327,22 +343,6 @@ VLM_TEST_SETTINGS = {
|
||||
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
|
||||
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
|
||||
),
|
||||
"idefics3": VLMTestInfo(
|
||||
models=["HuggingFaceM4/Idefics3-8B-Llama3"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "<image>",
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
transformers.__version__ < "4.46.0",
|
||||
reason="Model introduced in HF >= 4.46.0"
|
||||
),
|
||||
large_gpu_mark(min_gb=48),
|
||||
],
|
||||
),
|
||||
### Tensor parallel / multi-gpu broadcast tests
|
||||
"broadcast-chameleon": VLMTestInfo(
|
||||
models=["facebook/chameleon-7b"],
|
||||
@ -362,7 +362,7 @@ VLM_TEST_SETTINGS = {
|
||||
reason="Need at least 2 GPUs to run the test.",
|
||||
),
|
||||
pytest.mark.skipif(
|
||||
transformers.__version__.startswith("4.46"),
|
||||
transformers.__version__ < "4.46.2",
|
||||
reason="Model broken in HF, see huggingface/transformers#34379"
|
||||
)
|
||||
],
|
||||
|
@ -1,7 +1,8 @@
|
||||
import copy
|
||||
import enum
|
||||
import json
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
|
||||
Mapping, Optional, Set, Tuple, Type, Union)
|
||||
|
||||
@ -2078,6 +2079,12 @@ class VllmConfig:
|
||||
return quant_config
|
||||
return None
|
||||
|
||||
def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig":
|
||||
model_config = copy.deepcopy(self.model_config)
|
||||
model_config.hf_config = hf_config
|
||||
|
||||
return replace(self, model_config=model_config)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Verify configs are valid & consistent with each other.
|
||||
"""
|
||||
|
@ -229,7 +229,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
@ -246,9 +245,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
quant_config=quant_config,
|
||||
gather_output=True,
|
||||
)
|
||||
self.language_model = PersimmonForCausalLM(config.text_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
self.language_model = PersimmonForCausalLM(
|
||||
vllm_config.with_hf_config(config.text_config))
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
|
@ -164,10 +164,12 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(vllm_config, prefix=prefix)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
super().__init__(config, cache_config, quant_config)
|
||||
|
||||
self.model = InternLM2VEModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
|
@ -241,11 +241,11 @@ def init_vllm_registered_model(
|
||||
based on the arguments passed to the outer vLLM model.
|
||||
"""
|
||||
model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
|
||||
import copy
|
||||
copied_config = copy.deepcopy(vllm_config)
|
||||
copied_config.model_config.hf_config = hf_config
|
||||
|
||||
return model_class(vllm_config=copied_config, prefix=prefix)
|
||||
return model_class(
|
||||
vllm_config=vllm_config.with_hf_config(hf_config),
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
|
Loading…
x
Reference in New Issue
Block a user