[Bugfix] Fix and reorganize broken GGUF tests and bump gguf version (#16194)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py 2025-04-08 13:38:13 +08:00 committed by GitHub
parent b99733d092
commit f6b32efb7f
2 changed files with 61 additions and 21 deletions


@@ -28,7 +28,7 @@ filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/31
partial-json-parser # used for parsing partial JSON outputs
pyzmq
msgspec
gguf == 0.10.0
gguf >= 0.13.0
importlib_metadata
mistral_common[opencv] >= 1.5.4
opencv-python-headless >= 4.11.0 # required for video IO
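
For illustration only (not part of this commit): with the bumped gguf package, a GGUF checkpoint can be smoke-tested through vLLM's offline LLM API roughly as follows. The repo and filename mirror LLAMA_CONFIG in the test file below; the other settings are assumptions, not values taken from this diff.

# Hedged sketch: download a single GGUF file and run a short greedy generation.
from huggingface_hub import hf_hub_download
from vllm import LLM, SamplingParams

gguf_path = hf_hub_download(
    repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
    filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
)
llm = LLM(
    model=gguf_path,  # local path to the .gguf file
    tokenizer="meta-llama/Llama-3.2-1B-Instruct",  # tokenizer comes from the original model
    enforce_eager=True,
    max_model_len=4096,
)
print(llm.generate(["Hello, my name is"],
                   SamplingParams(temperature=0.0, max_tokens=8)))
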


@@ -9,11 +9,13 @@ from typing import NamedTuple
import pytest
from huggingface_hub import hf_hub_download
from pytest import MarkDecorator
from transformers import AutoTokenizer
from tests.quantization.utils import is_quant_method_supported
from ....conftest import VllmRunner
from ....utils import multi_gpu_test
from ...utils import check_logprobs_close
os.environ["TOKENIZERS_PARALLELISM"] = "true"
@@ -25,6 +27,7 @@ class GGUFTestConfig(NamedTuple):
original_model: str
gguf_repo: str
gguf_filename: str
marks: list[MarkDecorator] = []
@property
def gguf_model(self):
@@ -35,6 +38,7 @@ LLAMA_CONFIG = GGUFTestConfig(
original_model="meta-llama/Llama-3.2-1B-Instruct",
gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
marks=[pytest.mark.quant_model],
)
QWEN2_CONFIG = GGUFTestConfig(
@@ -81,34 +85,24 @@ MODELS = [
]
@pytest.mark.skipif(not is_quant_method_supported("gguf"),
reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [1, 2])
def test_models(
num_gpus_available: int,
def check_model_outputs(
vllm_runner: type[VllmRunner],
example_prompts: list[str],
prompts: list[str],
model: GGUFTestConfig,
dtype: str,
max_tokens: int,
num_logprobs: int,
tp_size: int,
) -> None:
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
):
tokenizer = AutoTokenizer.from_pretrained(model.original_model)
if tokenizer.chat_template is not None:
messages = [[{
'role': 'user',
'content': prompt
}] for prompt in example_prompts]
example_prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
}] for prompt in prompts]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
# Run gguf model.
with vllm_runner(model_name=model.gguf_model,
@@ -118,17 +112,19 @@ def test_models(
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=tp_size) as gguf_model:
gguf_outputs = gguf_model.generate_greedy_logprobs(
example_prompts[:-1], max_tokens, num_logprobs)
prompts[:-1], max_tokens, num_logprobs)
# Run unquantized model.
# Should run with tp=1, otherwise the test will get stuck at
# nccl initialization.
with vllm_runner(
model_name=model.original_model,
enforce_eager=True, # faster tests
dtype=dtype,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=tp_size) as original_model:
tensor_parallel_size=1) as original_model:
original_outputs = original_model.generate_greedy_logprobs(
example_prompts[:-1], max_tokens, num_logprobs)
prompts[:-1], max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=original_outputs,
@@ -136,3 +132,47 @@ def test_models(
name_0="original",
name_1="gguf",
)
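
For context, check_logprobs_close tolerates small quantization-induced drift rather than requiring identical outputs. A rough, simplified sketch of that idea (not the actual vLLM utility; the data shapes here are assumptions):

# Rough sketch only: when the quantized and unquantized runs diverge, each
# model's chosen token should still rank among the other model's top
# candidates at that step.
def rough_logprobs_check(tokens_a, tokens_b, topk_a, topk_b):
    # tokens_*: generated token ids per step; topk_*: per-step sets of the
    # top-num_logprobs candidate ids (hypothetical, simplified structures).
    for step, (ta, tb) in enumerate(zip(tokens_a, tokens_b)):
        if ta == tb:
            continue
        assert ta in topk_b[step] and tb in topk_a[step]
        break  # this sketch only inspects the first divergence
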
@pytest.mark.skipif(not is_quant_method_supported("gguf"),
reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize("model", [
pytest.param(test_config, marks=test_config.marks)
for test_config in MODELS
])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [1])
def test_models(
vllm_runner: type[VllmRunner],
example_prompts: list[str],
model: GGUFTestConfig,
dtype: str,
max_tokens: int,
num_logprobs: int,
tp_size: int,
) -> None:
check_model_outputs(vllm_runner, example_prompts, model, dtype, max_tokens,
num_logprobs, tp_size)
@pytest.mark.skipif(not is_quant_method_supported("gguf"),
reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize("model", [LLAMA_CONFIG])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [2])
@multi_gpu_test(num_gpus=2)
def test_distributed(
vllm_runner: type[VllmRunner],
example_prompts: list[str],
model: GGUFTestConfig,
dtype: str,
max_tokens: int,
num_logprobs: int,
tp_size: int,
) -> None:
check_model_outputs(vllm_runner, example_prompts, model, dtype, max_tokens,
num_logprobs, tp_size)
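
As a usage note (the file path and marker registration are assumptions, not part of this diff): the single-GPU comparisons and the two-GPU case can now be selected separately, e.g. pytest <path to this test file> -k test_models versus -k test_distributed. The @multi_gpu_test(num_gpus=2) decorator takes over the GPU-count skip that the removed num_gpus_available check used to perform, and the quant_model mark attached to LLAMA_CONFIG makes those parametrizations selectable with pytest -m quant_model, assuming that marker is registered in the project's pytest configuration.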