"""Compares the outputs of gptq vs gptq_marlin

Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 5 selections of each other.

Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the
test up to 3 times to see if we pass.

Run `pytest tests/models/test_gptq_marlin.py`.
"""
import os

import pytest

from tests.quantization.utils import is_quant_method_supported
from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT

from ...utils import check_logprobs_close

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024

MODELS = [
    # act_order==True, group_size=128
    ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"),
    # 8-bit, act_order==True, group_size=channelwise
    ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"),
    # 4-bit, act_order==True, group_size=128
    ("TechxGenus/gemma-1.1-2b-it-GPTQ", "main")
]


@pytest.mark.quant_model
@pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="gptq_marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    model_name, revision = model

    # Run marlin.
    with vllm_runner(model_name=model_name,
                     revision=revision,
                     dtype=dtype,
                     quantization="marlin",
                     max_model_len=MAX_MODEL_LEN,
                     tensor_parallel_size=1) as gptq_marlin_model:

        gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
            example_prompts[:-1], max_tokens, num_logprobs)
    _ROPE_DICT.clear()  # clear rope cache to avoid rope dtype error

    # Run gptq.
    # The naive gptq kernel doesn't support bf16 yet.
    # Here we always compare the fp16/bf16 gptq_marlin kernel
    # to the fp16 gptq kernel.
    with vllm_runner(model_name=model_name,
                     revision=revision,
                     dtype="half",
                     quantization="gptq",
                     max_model_len=MAX_MODEL_LEN,
                     tensor_parallel_size=1) as gptq_model:
        gptq_outputs = gptq_model.generate_greedy_logprobs(
            example_prompts[:-1], max_tokens, num_logprobs)
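
    # check_logprobs_close asserts that whenever the two runs choose different
    # tokens at a position, each run's chosen token still appears in the other
    # run's top-`num_logprobs` candidates (the tolerance described in the
    # module docstring).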
    check_logprobs_close(
        outputs_0_lst=gptq_outputs,
        outputs_1_lst=gptq_marlin_outputs,
        name_0="gptq",
        name_1="gptq_marlin",
    )