vllm/tests/models/embedding/language/test_cls_models.py
Robert Shaw d4d93db2c5
[V1] V1 Enablement Oracle (#13726)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2025-03-14 22:02:20 -07:00

53 lines
1.7 KiB
Python

# SPDX-License-Identifier: Apache-2.0
"""Compare the classification outputs of HF and vLLM models.
Run `pytest tests/models/test_cls_models.py`.
"""
import pytest
import torch
from transformers import AutoModelForSequenceClassification
from vllm.platforms import current_platform
@pytest.mark.parametrize(
"model",
[
pytest.param("jason9693/Qwen2.5-1.5B-apeach",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
],
)
@pytest.mark.parametrize("dtype",
["half"] if current_platform.is_rocm() else ["float"])
def test_classification_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
monkeypatch,
) -> None:
if current_platform.is_rocm():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)
with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)
# check logits difference
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)
# the tolerance value of 1e-2 is selected based on the
# half datatype tests in
# tests/models/embedding/language/test_embedding.py
assert torch.allclose(hf_output, vllm_output,
1e-3 if dtype == "float" else 1e-2)