[Bugfix] Fix and reorganize broken GGUF tests and bump gguf version (#16194)
Signed-off-by: Isotr0py <2037008807@qq.com>
parent b99733d092
commit f6b32efb7f
@@ -28,7 +28,7 @@ filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/31
 partial-json-parser # used for parsing partial JSON outputs
 pyzmq
 msgspec
-gguf == 0.10.0
+gguf >= 0.13.0
 importlib_metadata
 mistral_common[opencv] >= 1.5.4
 opencv-python-headless >= 4.11.0 # required for video IO
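The change above only raises the lower bound on the gguf dependency (0.10.0 to >= 0.13.0). As a quick sanity check, a minimal sketch, assuming it runs inside the environment where these requirements are installed (the check itself is not part of this commit):

from importlib.metadata import version

installed = version("gguf")  # e.g. "0.13.0"
major, minor = (int(p) for p in installed.split(".")[:2])
assert (major, minor) >= (0, 13), f"gguf >= 0.13.0 required, found {installed}"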
@@ -9,11 +9,13 @@ from typing import NamedTuple
 
 import pytest
 from huggingface_hub import hf_hub_download
+from pytest import MarkDecorator
 from transformers import AutoTokenizer
 
 from tests.quantization.utils import is_quant_method_supported
 
 from ....conftest import VllmRunner
+from ....utils import multi_gpu_test
 from ...utils import check_logprobs_close
 
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
@@ -25,6 +27,7 @@ class GGUFTestConfig(NamedTuple):
     original_model: str
     gguf_repo: str
     gguf_filename: str
+    marks: list[MarkDecorator] = []
 
     @property
     def gguf_model(self):
@@ -35,6 +38,7 @@ LLAMA_CONFIG = GGUFTestConfig(
     original_model="meta-llama/Llama-3.2-1B-Instruct",
     gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
     gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
+    marks=[pytest.mark.quant_model],
 )
 
 QWEN2_CONFIG = GGUFTestConfig(
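For context, each GGUFTestConfig pairs the original Hugging Face model with a quantized GGUF checkpoint; the gguf_model property (its body is outside this hunk) resolves gguf_repo/gguf_filename to a local file. A minimal sketch of that resolution using the Llama config above, assuming hf_hub_download is the mechanism:

from huggingface_hub import hf_hub_download

# Downloads the quantized checkpoint (or reuses the local HF cache) and
# returns the local file path.
local_path = hf_hub_download(
    repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
    filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
)
print(local_path)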
@@ -81,34 +85,24 @@ MODELS = [
 ]
 
 
-@pytest.mark.skipif(not is_quant_method_supported("gguf"),
-                    reason="gguf is not supported on this GPU type.")
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [32])
-@pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("tp_size", [1, 2])
-def test_models(
-    num_gpus_available: int,
+def check_model_outputs(
     vllm_runner: type[VllmRunner],
-    example_prompts: list[str],
+    prompts: list[str],
     model: GGUFTestConfig,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
     tp_size: int,
-) -> None:
-    if num_gpus_available < tp_size:
-        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
-
+):
     tokenizer = AutoTokenizer.from_pretrained(model.original_model)
     if tokenizer.chat_template is not None:
         messages = [[{
             'role': 'user',
             'content': prompt
-        }] for prompt in example_prompts]
-        example_prompts = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True)
+        }] for prompt in prompts]
+        prompts = tokenizer.apply_chat_template(messages,
+                                                tokenize=False,
+                                                add_generation_prompt=True)
 
     # Run gguf model.
     with vllm_runner(model_name=model.gguf_model,
@@ -118,17 +112,19 @@ def test_models(
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as gguf_model:
         gguf_outputs = gguf_model.generate_greedy_logprobs(
-            example_prompts[:-1], max_tokens, num_logprobs)
+            prompts[:-1], max_tokens, num_logprobs)
 
     # Run unquantized model.
+    # Should run with tp=1, otherwise the test will stuck at
+    # nccl initialization.
     with vllm_runner(
             model_name=model.original_model,
             enforce_eager=True,  # faster tests
             dtype=dtype,
             max_model_len=MAX_MODEL_LEN,
-            tensor_parallel_size=tp_size) as original_model:
+            tensor_parallel_size=1) as original_model:
         original_outputs = original_model.generate_greedy_logprobs(
-            example_prompts[:-1], max_tokens, num_logprobs)
+            prompts[:-1], max_tokens, num_logprobs)
 
     check_logprobs_close(
         outputs_0_lst=original_outputs,
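The VllmRunner fixture wraps vLLM's public entry point, so the quantized branch of check_model_outputs is roughly equivalent to loading the GGUF file directly. A minimal illustrative sketch (the checkpoint path and prompt are placeholders, not taken from this diff):

from vllm import LLM, SamplingParams

# Hypothetical local path to the downloaded GGUF file; the tokenizer comes
# from the original unquantized repo, as in the test configs above.
llm = LLM(model="/path/to/Llama-3.2-1B-Instruct-IQ4_XS.gguf",
          tokenizer="meta-llama/Llama-3.2-1B-Instruct",
          max_model_len=4096)
out = llm.generate(["Hello, my name is"],
                   SamplingParams(temperature=0.0, max_tokens=32))
print(out[0].outputs[0].text)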
@@ -136,3 +132,47 @@ def test_models(
         name_0="original",
         name_1="gguf",
     )
+
+
+@pytest.mark.skipif(not is_quant_method_supported("gguf"),
+                    reason="gguf is not supported on this GPU type.")
+@pytest.mark.parametrize("model", [
+    pytest.param(test_config, marks=test_config.marks)
+    for test_config in MODELS
+])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("tp_size", [1])
+def test_models(
+    vllm_runner: type[VllmRunner],
+    example_prompts: list[str],
+    model: GGUFTestConfig,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tp_size: int,
+) -> None:
+    check_model_outputs(vllm_runner, example_prompts, model, dtype, max_tokens,
+                        num_logprobs, tp_size)
+
+
+@pytest.mark.skipif(not is_quant_method_supported("gguf"),
+                    reason="gguf is not supported on this GPU type.")
+@pytest.mark.parametrize("model", [LLAMA_CONFIG])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [8])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("tp_size", [2])
+@multi_gpu_test(num_gpus=2)
+def test_distributed(
+    vllm_runner: type[VllmRunner],
+    example_prompts: list[str],
+    model: GGUFTestConfig,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tp_size: int,
+) -> None:
+    check_model_outputs(vllm_runner, example_prompts, model, dtype, max_tokens,
+                        num_logprobs, tp_size)