[Bugfix] Fix and reorganize broken GGUF tests and bump gguf version (#16194)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
b99733d092
commit
f6b32efb7f
@ -28,7 +28,7 @@ filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/31
|
||||
partial-json-parser # used for parsing partial JSON outputs
|
||||
pyzmq
|
||||
msgspec
|
||||
gguf == 0.10.0
|
||||
gguf >= 0.13.0
|
||||
importlib_metadata
|
||||
mistral_common[opencv] >= 1.5.4
|
||||
opencv-python-headless >= 4.11.0 # required for video IO
|
||||
|
@ -9,11 +9,13 @@ from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import hf_hub_download
|
||||
from pytest import MarkDecorator
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
|
||||
from ....conftest import VllmRunner
|
||||
from ....utils import multi_gpu_test
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
@ -25,6 +27,7 @@ class GGUFTestConfig(NamedTuple):
|
||||
original_model: str
|
||||
gguf_repo: str
|
||||
gguf_filename: str
|
||||
marks: list[MarkDecorator] = []
|
||||
|
||||
@property
|
||||
def gguf_model(self):
|
||||
@ -35,6 +38,7 @@ LLAMA_CONFIG = GGUFTestConfig(
|
||||
original_model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
|
||||
gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
|
||||
marks=[pytest.mark.quant_model],
|
||||
)
|
||||
|
||||
QWEN2_CONFIG = GGUFTestConfig(
|
||||
@ -81,34 +85,24 @@ MODELS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gguf"),
|
||||
reason="gguf is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("tp_size", [1, 2])
|
||||
def test_models(
|
||||
num_gpus_available: int,
|
||||
def check_model_outputs(
|
||||
vllm_runner: type[VllmRunner],
|
||||
example_prompts: list[str],
|
||||
prompts: list[str],
|
||||
model: GGUFTestConfig,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tp_size: int,
|
||||
) -> None:
|
||||
if num_gpus_available < tp_size:
|
||||
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model.original_model)
|
||||
if tokenizer.chat_template is not None:
|
||||
messages = [[{
|
||||
'role': 'user',
|
||||
'content': prompt
|
||||
}] for prompt in example_prompts]
|
||||
example_prompts = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True)
|
||||
}] for prompt in prompts]
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
# Run gguf model.
|
||||
with vllm_runner(model_name=model.gguf_model,
|
||||
@ -118,17 +112,19 @@ def test_models(
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=tp_size) as gguf_model:
|
||||
gguf_outputs = gguf_model.generate_greedy_logprobs(
|
||||
example_prompts[:-1], max_tokens, num_logprobs)
|
||||
prompts[:-1], max_tokens, num_logprobs)
|
||||
|
||||
# Run unquantized model.
|
||||
# Should run with tp=1, otherwise the test will stuck at
|
||||
# nccl initialization.
|
||||
with vllm_runner(
|
||||
model_name=model.original_model,
|
||||
enforce_eager=True, # faster tests
|
||||
dtype=dtype,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=tp_size) as original_model:
|
||||
tensor_parallel_size=1) as original_model:
|
||||
original_outputs = original_model.generate_greedy_logprobs(
|
||||
example_prompts[:-1], max_tokens, num_logprobs)
|
||||
prompts[:-1], max_tokens, num_logprobs)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=original_outputs,
|
||||
@ -136,3 +132,47 @@ def test_models(
|
||||
name_0="original",
|
||||
name_1="gguf",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gguf"),
|
||||
reason="gguf is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model", [
|
||||
pytest.param(test_config, marks=test_config.marks)
|
||||
for test_config in MODELS
|
||||
])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("tp_size", [1])
|
||||
def test_models(
|
||||
vllm_runner: type[VllmRunner],
|
||||
example_prompts: list[str],
|
||||
model: GGUFTestConfig,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tp_size: int,
|
||||
) -> None:
|
||||
check_model_outputs(vllm_runner, example_prompts, model, dtype, max_tokens,
|
||||
num_logprobs, tp_size)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gguf"),
|
||||
reason="gguf is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model", [LLAMA_CONFIG])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [8])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_distributed(
|
||||
vllm_runner: type[VllmRunner],
|
||||
example_prompts: list[str],
|
||||
model: GGUFTestConfig,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tp_size: int,
|
||||
) -> None:
|
||||
check_model_outputs(vllm_runner, example_prompts, model, dtype, max_tokens,
|
||||
num_logprobs, tp_size)
|
||||
|
Loading…
x
Reference in New Issue
Block a user