[Hardware][TPU] Add check for no additional graph compilation during runtime (#14710)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Authored by Siyuan Liu on 2025-03-20 20:05:28 -07:00, committed by GitHub
parent e588ac237c
commit b15fd2be2a
3 changed files with 32 additions and 6 deletions

View File

@@ -19,16 +19,18 @@ docker run --privileged --net host --shm-size=16G -it \
 vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
 && python3 -m pip install pytest \
 && python3 -m pip install lm_eval[api]==0.4.4 \
+&& export VLLM_USE_V1=1 \
+&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
 && echo TEST_1 \
-&& VLLM_USE_V1=1 python3 /workspace/vllm/tests/tpu/test_compilation.py \
+&& python3 /workspace/vllm/tests/tpu/test_compilation.py \
 && echo TEST_2 \
-&& VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
+&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
 && echo TEST_3 \
-&& VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
+&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
 && echo TEST_4 \
-&& VLLM_USE_V1=1 pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
+&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
 && echo TEST_5 \
-&& VLLM_USE_V1=1 python3 /workspace/vllm/examples/offline_inference/tpu.py" \
+&& python3 /workspace/vllm/examples/offline_inference/tpu.py" \
 # TODO: This test fails because it uses RANDOM_SEED sampling
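The two new `export` lines enable the V1 engine and the recompilation check once for the whole test chain, which is why the per-command `VLLM_USE_V1=1` prefixes are dropped. A rough sketch of reproducing the same setup for a single test outside the CI container (the test path comes from the script above; the subprocess wrapper is only an illustration):

```python
# Hypothetical local reproduction of the CI environment: enable the V1 engine
# and the XLA recompilation check, then run one of the TPU test scripts.
import os
import subprocess

env = dict(os.environ)
env["VLLM_USE_V1"] = "1"
env["VLLM_XLA_CHECK_RECOMPILATION"] = "1"

# Path mirrors the CI script; adjust it to your checkout.
subprocess.run(
    ["python3", "/workspace/vllm/tests/tpu/test_compilation.py"],
    env=env,
    check=True,
)
```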

View File

@@ -45,6 +45,7 @@ if TYPE_CHECKING:
     VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
     VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
     VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
+    VLLM_XLA_CHECK_RECOMPILATION: bool = False
     VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
     VLLM_USE_RAY_SPMD_WORKER: bool = False
     VLLM_USE_RAY_COMPILED_DAG: bool = False
@@ -446,6 +447,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
             "VLLM_XLA_CACHE_PATH",
             os.path.join(get_default_cache_root(), "vllm", "xla_cache"),
         )),
+    # If set, assert on XLA recompilation after each execution step.
+    "VLLM_XLA_CHECK_RECOMPILATION":
+    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
     "VLLM_FUSED_MOE_CHUNK_SIZE":
     lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")),
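The getter parses the variable with `bool(int(...))`, so `0` (or unset) disables the check and `1` enables it; a non-integer value such as `true` would raise `ValueError`. A simplified sketch of the lazy-getter pattern this entry plugs into, assuming a module-level `__getattr__` that resolves names through the table (the real `vllm/envs.py` has more machinery than shown here):

```python
# Simplified sketch of the lazy env-var pattern: typed defaults for static
# analysis, real values resolved from the environment at attribute access.
import os
from typing import TYPE_CHECKING, Any, Callable

if TYPE_CHECKING:
    VLLM_XLA_CHECK_RECOMPILATION: bool = False

environment_variables: dict[str, Callable[[], Any]] = {
    # If set, assert on XLA recompilation after each execution step.
    "VLLM_XLA_CHECK_RECOMPILATION":
    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
}

def __getattr__(name: str) -> Any:
    # Called for module attributes not found normally, so an access like
    # `envs.VLLM_XLA_CHECK_RECOMPILATION` evaluates the lambda each time.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")
```

In this scheme the flag is re-read from the environment on every access, which is what lets the model runner below pick it up at construction time.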

View File

@@ -11,6 +11,7 @@ import torch.nn as nn
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
+import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
@@ -73,6 +74,10 @@ class TPUModelRunner:
         scheduler_config = self.scheduler_config
         parallel_config = self.parallel_config
         self.device = device
+        self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
+        if self.check_recompilation:
+            self.num_xla_graphs = xr.get_num_cached_compilation_graph()
+        self.enforce_eager = model_config.enforce_eager
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
@@ -671,6 +676,12 @@
             logprobs=None,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
+        # Check that there is no new graph compilation; all graphs should be
+        # captured and compiled during warm up.
+        if self.check_recompilation and not self.enforce_eager:
+            curr_cached_graph = xr.get_num_cached_compilation_graph()
+            assert self.num_xla_graphs == curr_cached_graph, (
+                "Recompilation after warm up is detected.")
         return model_runner_output

     def load_model(self) -> None:
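The check deliberately skips eager mode: with `enforce_eager`, XLA traces and compiles as requests arrive, so a growing graph count is expected there. A standalone rendering of the same guard with the torch_xla counter injected as a callable (function and parameter names below are illustrative, not vLLM API):

```python
# Illustrative per-step guard: after warm up, the number of cached XLA graphs
# must not grow; if it does, a shape or graph was missed during warm up.
from typing import Callable

def check_no_runtime_recompilation(num_graphs_after_warmup: int,
                                   get_cached_graph_count: Callable[[], int],
                                   check_recompilation: bool,
                                   enforce_eager: bool) -> None:
    # Eager mode compiles on the fly by design, so it is exempt.
    if check_recompilation and not enforce_eager:
        curr_cached_graph = get_cached_graph_count()
        assert num_graphs_after_warmup == curr_cached_graph, (
            "Recompilation after warm up is detected: "
            f"{num_graphs_after_warmup} -> {curr_cached_graph}")
```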
@@ -810,6 +821,14 @@
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in %.2f [secs].", end - start)
+        # Record the number of cached XLA graphs after warm up; this is used
+        # to check that no additional graph compilation happens during
+        # runtime execution.
+        if self.check_recompilation:
+            total_cached_graphs = xr.get_num_cached_compilation_graph()
+            num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
+            logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
+            self.num_xla_graphs += num_compiled_graphs

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """