[Hardware][TPU] Add check for no additional graph compilation during runtime (#14710)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
This commit is contained in:
parent
e588ac237c
commit
b15fd2be2a
@ -19,17 +19,19 @@ docker run --privileged --net host --shm-size=16G -it \
|
||||
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
|
||||
&& python3 -m pip install pytest \
|
||||
&& python3 -m pip install lm_eval[api]==0.4.4 \
|
||||
&& export VLLM_USE_V1=1 \
|
||||
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
|
||||
&& echo TEST_1 \
|
||||
&& VLLM_USE_V1=1 python3 /workspace/vllm/tests/tpu/test_compilation.py \
|
||||
&& python3 /workspace/vllm/tests/tpu/test_compilation.py \
|
||||
&& echo TEST_2 \
|
||||
&& VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
|
||||
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
|
||||
&& echo TEST_3 \
|
||||
&& VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
|
||||
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
|
||||
&& echo TEST_4 \
|
||||
&& VLLM_USE_V1=1 pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
|
||||
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
|
||||
&& echo TEST_5 \
|
||||
&& VLLM_USE_V1=1 python3 /workspace/vllm/examples/offline_inference/tpu.py" \
|
||||
|
||||
&& python3 /workspace/vllm/examples/offline_inference/tpu.py" \
|
||||
|
||||
|
||||
# TODO: This test fails because it uses RANDOM_SEED sampling
|
||||
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
|
||||
|
@ -45,6 +45,7 @@ if TYPE_CHECKING:
|
||||
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
|
||||
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
|
||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
||||
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
|
||||
VLLM_USE_RAY_SPMD_WORKER: bool = False
|
||||
VLLM_USE_RAY_COMPILED_DAG: bool = False
|
||||
@ -446,6 +447,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_XLA_CACHE_PATH",
|
||||
os.path.join(get_default_cache_root(), "vllm", "xla_cache"),
|
||||
)),
|
||||
|
||||
# If set, assert on XLA recompilation after each execution step.
|
||||
"VLLM_XLA_CHECK_RECOMPILATION":
|
||||
lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
|
||||
"VLLM_FUSED_MOE_CHUNK_SIZE":
|
||||
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")),
|
||||
|
||||
|
@ -11,6 +11,7 @@ import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig
|
||||
@ -73,6 +74,10 @@ class TPUModelRunner:
|
||||
scheduler_config = self.scheduler_config
|
||||
parallel_config = self.parallel_config
|
||||
self.device = device
|
||||
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
|
||||
if self.check_recompilation:
|
||||
self.num_xla_graphs = xr.get_num_cached_compilation_graph()
|
||||
self.enforce_eager = model_config.enforce_eager
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
|
||||
@ -671,6 +676,12 @@ class TPUModelRunner:
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
)
|
||||
# Check there is no new graph compilation, all the graphs should be
|
||||
# captured and compiled during warming up.
|
||||
if self.check_recompilation and not self.enforce_eager:
|
||||
curr_cached_graph = xr.get_num_cached_compilation_graph()
|
||||
assert self.num_xla_graphs == curr_cached_graph, (
|
||||
"Recompilation after warm up is detected.")
|
||||
return model_runner_output
|
||||
|
||||
def load_model(self) -> None:
|
||||
@ -810,6 +821,14 @@ class TPUModelRunner:
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||
# Record the number cached XLA graph after warming up, this will be
|
||||
# used for checking there is no additional graph compilation during
|
||||
# runtime execution.
|
||||
if self.check_recompilation:
|
||||
total_cached_graphs = xr.get_num_cached_compilation_graph()
|
||||
num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
|
||||
logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
|
||||
self.num_xla_graphs += num_compiled_graphs
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user