[V1] TPU CI - Add basic perf regression test (#15414)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Parent: 2de4118243
Commit: 9a2160fa55
@@ -21,6 +21,8 @@ docker run --privileged --net host --shm-size=16G -it \
     && python3 -m pip install lm_eval[api]==0.4.4 \
     && export VLLM_USE_V1=1 \
     && export VLLM_XLA_CHECK_RECOMPILATION=1 \
+    && echo TEST_0 \
+    && pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \
     && echo TEST_1 \
     && pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \
     && echo TEST_2 \
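With this change the TPU CI container runs the new perf test as TEST_0, ahead of the existing compilation test, with VLLM_USE_V1 and VLLM_XLA_CHECK_RECOMPILATION both enabled. To reproduce that invocation outside the CI container, a minimal sketch follows (it assumes a TPU host with vllm installed and the repository as the working directory, which are not part of this diff):

# Minimal sketch: run the new TPU perf regression test with the same
# environment the CI script sets. Assumes a TPU host with vllm installed
# and the repo checked out in the current directory (not part of this diff).
import os
import sys

import pytest

os.environ["VLLM_USE_V1"] = "1"                   # use the V1 engine, as in CI
os.environ["VLLM_XLA_CHECK_RECOMPILATION"] = "1"  # flag post-warmup recompiles

# Equivalent to: pytest -v -s tests/v1/tpu/test_perf.py
sys.exit(pytest.main(["-v", "-s", "tests/v1/tpu/test_perf.py"]))

On a non-TPU host the test is simply skipped by its skipif marker.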
@@ -58,7 +58,7 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
     more_args = None
     if current_platform.is_tpu():
         # Limit compilation time for TPU V1
-        more_args = "max_num_seqs=64"
+        more_args = "max_model_len=2048,max_num_seqs=64"
 
     # Add TP test (if provided)
     if TPU_TP_TEST_STR:
@@ -32,7 +32,7 @@ TENSOR_PARALLEL_SIZES = [1]
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
-def test_models(
+def test_basic(
     vllm_runner: type[VllmRunner],
     monkeypatch: pytest.MonkeyPatch,
     model: str,
@@ -58,4 +58,5 @@ def test_models(
             vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                       max_tokens)
             output = vllm_outputs[0][1]
-            assert "1024" in output
+
+            assert "1024" in output or "0, 1" in output
tests/v1/tpu/test_perf.py (new file, 146 lines)
@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
"""A basic performance regression test for TPUs

Run `pytest tests/v1/tpu/test_perf.py`.
"""
from __future__ import annotations

import time
from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np
import pytest

from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

if TYPE_CHECKING:
    from tests.conftest import VllmRunner


@dataclass
class TestParams:
    model: str
    num_prompts: int
    prefix_len: int
    decode_len: int
    expected_avg_time: float
    err_tol: float


TEST_PARAMS = [
    # TODO: Cannot run a series of tests because:
    # RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed:
    # open(/dev/vfio/0): Device or resource busy: Device or resource busy;
    # Couldn't open iommu group /dev/vfio/0
    # => Investigate

    # TestParams(
    #     model="Qwen/Qwen2.5-1.5B-Instruct",
    #     num_prompts=1,
    #     prefix_len=10,
    #     decode_len=5,
    #     expected_avg_time=0.03,
    #     err_tol=0.01,
    # ),
    # TestParams(
    #     model="Qwen/Qwen2.5-1.5B-Instruct",
    #     num_prompts=10,
    #     prefix_len=100,
    #     decode_len=50,
    #     expected_avg_time=0.234,
    #     err_tol=0.020,
    # ),
    TestParams(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        num_prompts=64,
        prefix_len=500,
        decode_len=50,

        # (This is the active CI/CD instance)
        # commit id: ccb246776d93ef105904a8ec015b3587240a1183
        # tpu: v5lite (vllm CI/CD)
        expected_avg_time=1.4,
        err_tol=0.30,

        # (TODO: There is no v6e in CI/CD currently)
        # commit id: ccb246776d93ef105904a8ec015b3587240a1183
        # tpu: v6e
        # expected_avg_time=1.5,
        # err_tol=0.20,
    ),
]

NUM_WARMUPS = 5
NUM_RUNS = 10

MAX_MODEL_LEN = 1024
MAX_NUM_SEQS = 32
GPU_UTIL = 0.9


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic performance test for TPU only")
@pytest.mark.parametrize("params", TEST_PARAMS)
def test_perf(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
    params: TestParams,
) -> None:
    tokenizer = get_tokenizer(params.model,
                              tokenizer_mode="auto",
                              trust_remote_code=True)

    prompts = []
    for i in range(params.num_prompts):
        prefix_token_ids = np.random.randint(0,
                                             tokenizer.vocab_size,
                                             size=params.prefix_len).tolist()
        prompt = tokenizer.decode(prefix_token_ids)
        prompts.append(prompt)

    print(
        "-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format(
            len(prompts), params.prefix_len, params.decode_len))

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        sampling_params = SamplingParams(max_tokens=params.decode_len,
                                         temperature=1.0,
                                         min_p=0.0)

        with vllm_runner(params.model,
                         max_num_batched_tokens=MAX_MODEL_LEN,
                         max_model_len=MAX_MODEL_LEN,
                         max_num_seqs=MAX_NUM_SEQS,
                         gpu_memory_utilization=GPU_UTIL,
                         enforce_eager=False,
                         tensor_parallel_size=1) as vllm_model:
            print(" -- Warmup / Compile")
            for i in range(NUM_WARMUPS):
                _ = vllm_model.generate(prompts, sampling_params)

            print(" -- Benchmarking... ")
            times = []
            for i in range(NUM_RUNS):
                start_time = time.time()
                _ = vllm_model.generate(prompts, sampling_params)
                times.append(time.time() - start_time)

            avg_time = sum(times) / len(times)

            print(" -- avg_time = {}".format(avg_time))
            print(" -- expected_avg_time = {} with err_tol = {}".format(
                params.expected_avg_time, params.err_tol))
            diff = avg_time - params.expected_avg_time
            ok = diff < params.err_tol
            if diff < -params.err_tol:
                print(" !! WARNING !! Performance has improved by {}, "
                      "it may be necessary to fine-tune the "
                      "expected_avg_time = {}".format(
                          -diff, params.expected_avg_time))

            assert ok, " !! ERROR !! Regression detected"
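A note on the acceptance rule above: a run fails only when avg_time exceeds expected_avg_time by more than err_tol; a run that is faster by more than err_tol still passes but prints a warning suggesting the baseline be retuned. A small standalone sketch of that rule using the v5lite numbers from TEST_PARAMS (the helper name is illustrative, not part of the test):

# Illustrative sketch of the pass/fail rule in test_perf (helper name is not
# part of the diff). With expected_avg_time=1.4 and err_tol=0.30, any run
# under 1.70 s passes; a run under 1.10 s also passes but suggests retuning.
def check_perf(avg_time: float, expected_avg_time: float, err_tol: float) -> bool:
    diff = avg_time - expected_avg_time
    if diff < -err_tol:
        print(f"faster than expected by {-diff:.3f} s; consider retuning the baseline")
    return diff < err_tol  # regression iff slower than expected_avg_time + err_tol


assert check_perf(1.65, 1.4, 0.30)      # within tolerance -> pass
assert check_perf(1.05, 1.4, 0.30)      # much faster -> pass, but warns
assert not check_perf(1.75, 1.4, 0.30)  # slower than 1.70 s -> regression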
@@ -77,9 +77,12 @@ class TPUModelRunner:
         parallel_config = self.parallel_config
         self.device = device
         self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
-        if self.check_recompilation:
-            self.num_xla_graphs = xr.get_num_cached_compilation_graph()
 
+        self.enforce_eager = model_config.enforce_eager
+
+        self.num_xla_graphs = 0
+        self._update_num_xla_graphs("init")
+
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
         self._hidden_states_dtype = self.dtype
@@ -180,6 +183,31 @@ class TPUModelRunner:
             max_token_size=self.max_num_tokens,
             padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
 
+    def _update_num_xla_graphs(self, case_str):
+        check_comp = self.check_recompilation and not self.enforce_eager
+        if not check_comp:
+            return
+
+        total_cached_graphs = xr.get_num_cached_compilation_graph()
+        new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
+        if new_compiled_graphs == 0:
+            return
+
+        logger.info("Add new %d compiled XLA graphs due to %s",
+                    new_compiled_graphs, case_str)
+        self.num_xla_graphs += new_compiled_graphs
+
+    def _verify_num_xla_graphs(self, case_str):
+        check_comp = self.check_recompilation and not self.enforce_eager
+        if not check_comp:
+            return
+
+        curr_cached_graph = xr.get_num_cached_compilation_graph()
+        assert self.num_xla_graphs == curr_cached_graph, (
+            "Recompilation after warm up is detected during {}."
+            " num_xla_graphs = {} curr_cached_graph = {}".format(
+                case_str, self.num_xla_graphs, curr_cached_graph))
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         """Update the cached states and the persistent batch with the scheduler
         output.
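The two helpers above form a record-then-verify guard: _update_num_xla_graphs records how many XLA graphs the compilation cache holds after each capture phase, and _verify_num_xla_graphs asserts on the hot path that the count has not grown since warm up. A minimal, framework-agnostic sketch of the same pattern (the class name and the count_graphs callback are illustrative; in the runner the count comes from xr.get_num_cached_compilation_graph() and the guard is disabled when enforce_eager is set):

# Illustrative sketch of the record/verify guard used by TPUModelRunner.
# `count_graphs` stands in for xr.get_num_cached_compilation_graph().
from typing import Callable


class RecompilationGuard:

    def __init__(self, count_graphs: Callable[[], int], enabled: bool = True):
        self.count_graphs = count_graphs
        self.enabled = enabled
        self.num_graphs = 0

    def update(self, case_str: str) -> None:
        # Called after each compilation phase (e.g. "init", "model", "sampling").
        if not self.enabled:
            return
        new_graphs = self.count_graphs() - self.num_graphs
        if new_graphs:
            print(f"{new_graphs} new graphs compiled during {case_str}")
            self.num_graphs += new_graphs

    def verify(self, case_str: str) -> None:
        # Called on the hot path (e.g. "execute_model"); any growth in the
        # cached-graph count means an unexpected recompilation after warm up.
        if not self.enabled:
            return
        assert self.count_graphs() == self.num_graphs, (
            f"Recompilation detected during {case_str}")


# Example with a fake counter standing in for the XLA compilation cache:
graphs = [0]
guard = RecompilationGuard(lambda: graphs[0])
graphs[0] = 7                  # pretend warm-up compiled 7 graphs
guard.update("model")          # record them
guard.verify("execute_model")  # passes: no new graphs since warm up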
@@ -694,12 +722,11 @@ class TPUModelRunner:
             logprobs=None,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
-        # Check there is no new graph compilation, all the graphs should be
-        # captured and compiled during warming up.
-        if self.check_recompilation and not self.enforce_eager:
-            curr_cached_graph = xr.get_num_cached_compilation_graph()
-            assert self.num_xla_graphs == curr_cached_graph, (
-                "Recompilation after warm up is detected.")
+
+        # Check there are no new graphs compiled - all the graphs should be
+        # captured and compiled during warm up.
+        self._verify_num_xla_graphs("execute_model")
+
         return model_runner_output
 
     def load_model(self) -> None:
@@ -797,7 +824,9 @@ class TPUModelRunner:
             xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
+
         logger.info("Compilation finished in in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("model")
 
         logger.info("Compiling sampling with different input shapes.")
         start = time.perf_counter()
@@ -832,15 +861,9 @@ class TPUModelRunner:
                     num_reqs_to_sample + 1, self.max_num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in %.2f [secs].", end - start)
-        # Record the number cached XLA graph after warming up, this will be
-        # used for checking there is no additional graph compilation during
-        # runtime execution.
-        if self.check_recompilation:
-            total_cached_graphs = xr.get_num_cached_compilation_graph()
-            num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
-            logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
-            self.num_xla_graphs += num_compiled_graphs
+
+        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("sampling")
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
|
Loading…
x
Reference in New Issue
Block a user