# SPDX-License-Identifier: Apache-2.0
"""A basic correctness check for TPUs

Run `pytest tests/v1/tpu/test_basic.py`.
"""
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
|
2025-03-08 08:19:38 -05:00
|
|
|
import pytest
|
|
|
|
|
|
|
|
from vllm.platforms import current_platform
|
|
|
|
|
2025-03-17 11:35:57 +08:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
from tests.conftest import VllmRunner
|
2025-03-08 08:19:38 -05:00
|
|
|
|
|
|
|
# Models exercised by the basic TPU smoke test.
MODELS = [
    "Qwen/Qwen2.5-1.5B-Instruct",
    # TODO: Enable these models with v6e
    # "Qwen/Qwen2-7B-Instruct",
    # "meta-llama/Llama-3.1-8B",
]

TENSOR_PARALLEL_SIZES = [1]
# TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
def test_models(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    max_tokens: int,
    tensor_parallel_size: int,
) -> None:
    """Greedy-generate from a long counting prompt on TPU and check output.

    The prompt enumerates the integers 0..1023, so with
    ``max_num_batched_tokens=1024`` the prefill must be processed in
    chunks; a correct continuation of the sequence should contain
    "1024" in the first few generated tokens.
    """
    prompt = "The next numbers of the sequence " + ", ".join(
        str(i) for i in range(1024)) + " are:"
    example_prompts = [prompt]

    with monkeypatch.context() as m:
        # Force the v1 engine for the duration of this test.
        m.setenv("VLLM_USE_V1", "1")

        with vllm_runner(
                model,
                # Note: max_num_batched_tokens == 1024 is needed here to
                # actually test chunked prompt
                max_num_batched_tokens=1024,
                # NOTE(review): 8196 looks like a typo for 8192 — kept
                # as-is to preserve behavior; confirm intent upstream.
                max_model_len=8196,
                gpu_memory_utilization=0.7,
                max_num_seqs=16,
                tensor_parallel_size=tensor_parallel_size) as vllm_model:
            vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                      max_tokens)
        # generate_greedy returns (token_ids, text) pairs; check the text
        # of the single prompt's completion.
        output = vllm_outputs[0][1]
        assert "1024" in output
|