vllm/tests/models/test_mistral.py

45 lines
1.2 KiB
Python
Raw Normal View History

"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_mistral.py`.
"""
import pytest
from .utils import check_logprobs_close
# Hugging Face Hub model ids checked by this module; each entry becomes one
# parametrization of ``test_models`` via ``pytest.mark.parametrize``.
MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.3",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare greedy-decoding logprobs from HF and vLLM for one model.

    The HF runner is built, queried, and dereferenced first. Only then is the
    vLLM runner built, queried, and dereferenced. The two result sets are then
    compared with ``check_logprobs_close``, labelled ``"hf"`` and ``"vllm"``.

    Args:
        hf_runner: fixture that builds the HuggingFace-backed model wrapper.
        vllm_runner: fixture that builds the vLLM-backed model wrapper.
        example_prompts: fixture supplying the prompts fed to both backends.
        model: model id taken from ``MODELS``.
        dtype: dtype string passed to both runners.
        max_tokens: number of tokens to generate per prompt.
        num_logprobs: number of logprobs requested from each backend.
    """
    # TODO(sang): Sliding window should be tested separately.

    # Reference side: HuggingFace. The ``_limit`` variant is used here,
    # presumably to cap HF's returned logprobs at ``num_logprobs`` (confirm
    # against the runner implementation).
    reference = hf_runner(model, dtype=dtype)
    reference_outputs = reference.generate_greedy_logprobs_limit(
        example_prompts, max_tokens, num_logprobs)
    # Drop the HF wrapper before constructing the vLLM one so both are not
    # held at once (presumably to free accelerator memory -- confirm).
    del reference

    # Candidate side: vLLM with identical model/dtype/prompts/settings.
    candidate = vllm_runner(model, dtype=dtype)
    candidate_outputs = candidate.generate_greedy_logprobs(
        example_prompts, max_tokens, num_logprobs)
    del candidate

    check_logprobs_close(
        outputs_0_lst=reference_outputs,
        outputs_1_lst=candidate_outputs,
        name_0="hf",
        name_1="vllm",
    )