[V1] TPU CI - Fix test_compilation.py (#15570)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
Alexander Matveev 2025-03-26 17:51:54 -04:00 committed by GitHub
parent b2e85e26f4
commit 9d119a86ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 41 deletions

View File

@ -22,7 +22,7 @@ docker run --privileged --net host --shm-size=16G -it \
&& export VLLM_USE_V1=1 \
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
&& echo TEST_1 \
&& pytest /workspace/vllm/tests/tpu/test_compilation.py \
&& pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \
&& echo TEST_2 \
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
&& echo TEST_3 \

View File

@ -5,12 +5,8 @@ import os
import tempfile
import depyf
import pytest
from vllm.config import CompilationLevel
@pytest.mark.skip(reason="Not working; needs investigation.")
def test_tpu_compilation():
temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
@ -22,27 +18,24 @@ def test_tpu_compilation():
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, through inaction, allow a human being to come to harm.",
" what is essential is invisible to the eye.",
" but in rising every time we fall.",
" or, through inaction",
" what is essential ",
" but in rising ",
]
N = 1
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
N = 1
sampling_params = SamplingParams(temperature=0.7,
top_p=1.0,
n=N,
max_tokens=16)
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforace_eager` should be `False`.
# disable custom dispatcher, let Dynamo takes over
# all the control
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
max_model_len=512,
max_num_seqs=64,
enforce_eager=True,
compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
max_num_batched_tokens=256,
max_model_len=256,
max_num_seqs=32,
enforce_eager=False)
outputs = llm.generate(prompts, sampling_params)
for output, answer in zip(outputs, answers):
prompt = output.prompt
@ -56,16 +49,11 @@ def test_tpu_compilation():
for i, compiled_code in enumerate(compiled_codes):
print("{} file: {}".format(i + 1, compiled_code))
# We should only trigger Dynamo compilation 4 times:
# 1. forward pass (symbolic)
# 2. compute_logits (symbolic)
# 3. forward pass (shape 16)
# 4. forward pass (shape 32)
# and later calls should not trigger Dynamo compilation again.
# NOTE: It might still trigger XLA compilation.
# We should only trigger Dynamo compilation 2 times:
# 1. Forward pass without kv_caches
# 2. Forward pass with kv_caches
# Check we have 4 compiled codes
assert len(compiled_codes) == 4
assert len(compiled_codes) == 2
kv_cache_prefix = "kv_cache"
attn_prefix = "ragged_paged_attention"
@ -77,24 +65,13 @@ def test_tpu_compilation():
for i, compiled_fn in enumerate(compiled_fns):
print("{} file: {}".format(i + 1, compiled_fn))
# The first compilation is symbolic, so it should not have any kv_caches
# The first compilation should not have any kv_caches
with open(compiled_fns[0]) as f:
content = f.read()
assert kv_cache_prefix not in content
# The second compilation is symbolic, so it should not have any kv_caches
# The second compilation should have kv_caches and the
# ragged_paged_attention
with open(compiled_fns[1]) as f:
content = f.read()
assert kv_cache_prefix not in content
# The third compilation is shape 16, so it should have kv_caches and the
# ragged_paged_attention
with open(compiled_fns[2]) as f:
content = f.read()
assert (kv_cache_prefix in content and attn_prefix in content)
# The forth compilation is shape 32, so it should have kv_caches and the
# ragged_paged_attention
with open(compiled_fns[3]) as f:
content = f.read()
assert (kv_cache_prefix in content and attn_prefix in content)