[V1] TPU CI - Add basic perf regression test (#15414)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Alexander Matveev 2025-03-31 13:25:20 -04:00 committed by GitHub
parent 2de4118243
commit 9a2160fa55
5 changed files with 192 additions and 20 deletions


@@ -21,6 +21,8 @@ docker run --privileged --net host --shm-size=16G -it \
&& python3 -m pip install lm_eval[api]==0.4.4 \
&& export VLLM_USE_V1=1 \
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
&& echo TEST_0 \
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \
&& echo TEST_1 \
&& pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \
&& echo TEST_2 \
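To run the same chain outside the CI container, a small local driver along these lines should work. This is a hedged sketch, not part of the commit: the script itself, its environment handling, and the relative test paths (run from a vLLM checkout) are assumptions; only the env vars and pytest invocations mirror the CI command above.

```python
# Hypothetical local driver mirroring the CI test chain above (sketch only).
import os
import subprocess

env = dict(os.environ,
           VLLM_USE_V1="1",
           VLLM_XLA_CHECK_RECOMPILATION="1")
for test in ("tests/v1/tpu/test_perf.py",
             "tests/tpu/test_compilation.py"):
    print(f"TEST: {test}")
    subprocess.run(["pytest", "-v", "-s", test], env=env, check=True)
```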


@@ -58,7 +58,7 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
more_args = None
if current_platform.is_tpu():
# Limit compilation time for TPU V1
more_args = "max_num_seqs=64"
more_args = "max_model_len=2048,max_num_seqs=64"
# Add TP test (if provided)
if TPU_TP_TEST_STR:
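more_args is a comma-separated key=value string of engine overrides for the TPU run. Purely as an illustration of that format (this is not the accuracy test's actual parsing code, and engine_kwargs is a hypothetical name):

```python
# Illustration only: unpacking the comma-separated override string used above.
more_args = "max_model_len=2048,max_num_seqs=64"
engine_kwargs = {key: int(value)
                 for key, value in (item.split("=")
                                    for item in more_args.split(","))}
assert engine_kwargs == {"max_model_len": 2048, "max_num_seqs": 64}
```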


@@ -32,7 +32,7 @@ TENSOR_PARALLEL_SIZES = [1]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
def test_models(
def test_basic(
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
model: str,
@@ -58,4 +58,5 @@ def test_models(
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
output = vllm_outputs[0][1]
assert "1024" in output
assert "1024" in output or "0, 1" in output

tests/v1/tpu/test_perf.py Normal file

@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
"""A basic performance regression test for TPUs
Run `pytest tests/v1/tpu/test_perf.py`.
"""
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
import pytest
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
if TYPE_CHECKING:
from tests.conftest import VllmRunner
@dataclass
class TestParams:
model: str
num_prompts: int
prefix_len: int
decode_len: int
expected_avg_time: float
err_tol: float
TEST_PARAMS = [
# TODO: Cannot run a series of tests because:
# RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed:
# open(/dev/vfio/0): Device or resource busy: Device or resource busy;
# Couldn't open iommu group /dev/vfio/0
# => Investigate
# TestParams(
# model="Qwen/Qwen2.5-1.5B-Instruct",
# num_prompts=1,
# prefix_len=10,
# decode_len=5,
# expected_avg_time=0.03,
# err_tol=0.01,
# ),
# TestParams(
# model="Qwen/Qwen2.5-1.5B-Instruct",
# num_prompts=10,
# prefix_len=100,
# decode_len=50,
# expected_avg_time=0.234,
# err_tol=0.020,
# ),
TestParams(
model="Qwen/Qwen2.5-1.5B-Instruct",
num_prompts=64,
prefix_len=500,
decode_len=50,
# (This is the active CI/CD instance)
# commit id: ccb246776d93ef105904a8ec015b3587240a1183
# tpu: v5lite (vllm CI/CD)
expected_avg_time=1.4,
err_tol=0.30,
# (TODO: There is no v6e in CI/CD currently)
# commit id: ccb246776d93ef105904a8ec015b3587240a1183
# tpu: v6e
# expected_avg_time=1.5,
# err_tol=0.20,
),
]
NUM_WARMUPS = 5
NUM_RUNS = 10
MAX_MODEL_LEN = 1024
MAX_NUM_SEQS = 32
GPU_UTIL = 0.9
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This is a basic performance test for TPU only")
@pytest.mark.parametrize("params", TEST_PARAMS)
def test_perf(
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
params: TestParams,
) -> None:
tokenizer = get_tokenizer(params.model,
tokenizer_mode="auto",
trust_remote_code=True)
prompts = []
for i in range(params.num_prompts):
prefix_token_ids = np.random.randint(0,
tokenizer.vocab_size,
size=params.prefix_len).tolist()
prompt = tokenizer.decode(prefix_token_ids)
prompts.append(prompt)
print(
"-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format(
len(prompts), params.prefix_len, params.decode_len))
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
sampling_params = SamplingParams(max_tokens=params.decode_len,
temperature=1.0,
min_p=0.0)
with vllm_runner(params.model,
max_num_batched_tokens=MAX_MODEL_LEN,
max_model_len=MAX_MODEL_LEN,
max_num_seqs=MAX_NUM_SEQS,
gpu_memory_utilization=GPU_UTIL,
enforce_eager=False,
tensor_parallel_size=1) as vllm_model:
print(" -- Warmup / Compile")
for i in range(NUM_WARMUPS):
_ = vllm_model.generate(prompts, sampling_params)
print(" -- Benchmarking... ")
times = []
for i in range(NUM_RUNS):
start_time = time.time()
_ = vllm_model.generate(prompts, sampling_params)
times.append(time.time() - start_time)
avg_time = sum(times) / len(times)
print(" -- avg_time = {}".format(avg_time))
print(" -- expected_avg_time = {} with err_tol = {}".format(
params.expected_avg_time, params.err_tol))
diff = avg_time - params.expected_avg_time
ok = diff < params.err_tol
if diff < -params.err_tol:
print(" !! WARNING !! Performance has improved by {}, "
"it may be necessary to fine-tune the "
"expected_avg_time = {}".format(
-diff, params.expected_avg_time))
assert ok, " !! ERROR !! Regression detected"
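The check above is one-sided: a run only fails when avg_time exceeds expected_avg_time by at least err_tol, while a large improvement merely prints a warning suggesting the baseline be retuned. A minimal sketch of that arithmetic with the v5lite baseline from TEST_PARAMS (regression_detected is a hypothetical helper, not part of the test):

```python
# Sketch of the one-sided tolerance rule used by test_perf (illustration only).
def regression_detected(avg_time: float,
                        expected_avg_time: float,
                        err_tol: float) -> bool:
    return (avg_time - expected_avg_time) >= err_tol

# With expected_avg_time=1.4 and err_tol=0.30 (the v5lite baseline above):
assert not regression_detected(1.65, 1.4, 0.30)  # 0.25 over baseline -> pass
assert not regression_detected(1.05, 1.4, 0.30)  # faster -> pass (warn only)
assert regression_detected(1.75, 1.4, 0.30)      # 0.35 over baseline -> fail
```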


@@ -77,9 +77,12 @@ class TPUModelRunner:
parallel_config = self.parallel_config
self.device = device
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
if self.check_recompilation:
self.num_xla_graphs = xr.get_num_cached_compilation_graph()
self.enforce_eager = model_config.enforce_eager
self.num_xla_graphs = 0
self._update_num_xla_graphs("init")
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self._hidden_states_dtype = self.dtype
@@ -180,6 +183,31 @@ class TPUModelRunner:
max_token_size=self.max_num_tokens,
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
return
total_cached_graphs = xr.get_num_cached_compilation_graph()
new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
if new_compiled_graphs == 0:
return
logger.info("Added %d new compiled XLA graphs due to %s",
            new_compiled_graphs, case_str)
self.num_xla_graphs += new_compiled_graphs
def _verify_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
return
curr_cached_graph = xr.get_num_cached_compilation_graph()
assert self.num_xla_graphs == curr_cached_graph, (
"Recompilation after warm up is detected during {}."
" num_xla_graphs = {} curr_cached_graph = {}".format(
case_str, self.num_xla_graphs, curr_cached_graph))
def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
"""Update the cached states and the persistent batch with the scheduler
output.
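The two helpers above centralize the recompilation guard: each expected compilation phase ("init", "model", "sampling") calls _update_num_xla_graphs to absorb newly cached graphs, and execute_model (see the later hunk) calls _verify_num_xla_graphs to assert the count has not grown since warm-up. Below is a self-contained sketch of the same pattern with a stand-in counter instead of xr.get_num_cached_compilation_graph(); the RecompilationGuard class and the toy counter are hypothetical, for illustration only.

```python
# Illustration of the update/verify pattern above, decoupled from the runner.
from typing import Callable


class RecompilationGuard:
    def __init__(self, get_num_graphs: Callable[[], int]) -> None:
        # get_num_graphs stands in for xr.get_num_cached_compilation_graph.
        self._get_num_graphs = get_num_graphs
        self.num_xla_graphs = 0

    def update(self, case_str: str) -> None:
        # Absorb graphs compiled during an expected phase (init/model/sampling).
        new = self._get_num_graphs() - self.num_xla_graphs
        if new:
            print(f"Added {new} compiled XLA graphs during {case_str}")
            self.num_xla_graphs += new

    def verify(self, case_str: str) -> None:
        # Any growth outside the expected phases means a runtime recompile.
        curr = self._get_num_graphs()
        assert self.num_xla_graphs == curr, (
            f"Recompilation after warm up detected during {case_str}: "
            f"{self.num_xla_graphs} != {curr}")


# Toy usage: simulate two compile phases, then verify nothing else compiled.
count = {"graphs": 0}
guard = RecompilationGuard(lambda: count["graphs"])
count["graphs"] = 3
guard.update("model")          # warm-up compiled 3 graphs
count["graphs"] = 5
guard.update("sampling")       # sampling compiled 2 more
guard.verify("execute_model")  # passes: no new graphs since warm-up
```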
@@ -694,12 +722,11 @@ class TPUModelRunner:
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict,
)
# Check there is no new graph compilation, all the graphs should be
# captured and compiled during warming up.
if self.check_recompilation and not self.enforce_eager:
curr_cached_graph = xr.get_num_cached_compilation_graph()
assert self.num_xla_graphs == curr_cached_graph, (
"Recompilation after warm up is detected.")
# Check that no new graphs were compiled - all graphs should have been
# captured and compiled during warm-up.
self._verify_num_xla_graphs("execute_model")
return model_runner_output
def load_model(self) -> None:
@@ -797,7 +824,9 @@ class TPUModelRunner:
xm.mark_step()
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("model")
logger.info("Compiling sampling with different input shapes.")
start = time.perf_counter()
@@ -832,15 +861,9 @@ class TPUModelRunner:
num_reqs_to_sample + 1, self.max_num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
# Record the number cached XLA graph after warming up, this will be
# used for checking there is no additional graph compilation during
# runtime execution.
if self.check_recompilation:
total_cached_graphs = xr.get_num_cached_compilation_graph()
num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
self.num_xla_graphs += num_compiled_graphs
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("sampling")
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""