vllm/tests/models/test_mistral.py

"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_mistral.py`.
"""
import pytest

from .utils import check_logprobs_close
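
# Models exercised by this suite; each entry becomes its own parametrized case.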
MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.1",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
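    # `hf_runner`, `vllm_runner`, and `example_prompts` are pytest fixtures
    # provided by the shared tests/conftest.py; they wrap HuggingFace and
    # vLLM model loading and supply a common set of prompts.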
    # TODO(sang): Sliding window should be tested separately.
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy_logprobs_limit(
        example_prompts, max_tokens, num_logprobs)
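    # Drop the reference so the HF model can be garbage-collected, freeing
    # GPU memory before the vLLM engine is constructed.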
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts,
                                                       max_tokens,
                                                       num_logprobs)
    del vllm_model

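    # check_logprobs_close tolerates small numerical differences between the
    # two implementations: positions whose greedy tokens differ still pass as
    # long as each token appears in the other run's top-`num_logprobs`
    # candidates.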
    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )