[Hardware][TPU] Add check for no additional graph compilation during runtime (#14710)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
This commit is contained in:
parent
e588ac237c
commit
b15fd2be2a
@ -19,16 +19,18 @@ docker run --privileged --net host --shm-size=16G -it \
|
|||||||
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
|
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
|
||||||
&& python3 -m pip install pytest \
|
&& python3 -m pip install pytest \
|
||||||
&& python3 -m pip install lm_eval[api]==0.4.4 \
|
&& python3 -m pip install lm_eval[api]==0.4.4 \
|
||||||
|
&& export VLLM_USE_V1=1 \
|
||||||
|
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
|
||||||
&& echo TEST_1 \
|
&& echo TEST_1 \
|
||||||
&& VLLM_USE_V1=1 python3 /workspace/vllm/tests/tpu/test_compilation.py \
|
&& python3 /workspace/vllm/tests/tpu/test_compilation.py \
|
||||||
&& echo TEST_2 \
|
&& echo TEST_2 \
|
||||||
&& VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
|
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
|
||||||
&& echo TEST_3 \
|
&& echo TEST_3 \
|
||||||
&& VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
|
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
|
||||||
&& echo TEST_4 \
|
&& echo TEST_4 \
|
||||||
&& VLLM_USE_V1=1 pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
|
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
|
||||||
&& echo TEST_5 \
|
&& echo TEST_5 \
|
||||||
&& VLLM_USE_V1=1 python3 /workspace/vllm/examples/offline_inference/tpu.py" \
|
&& python3 /workspace/vllm/examples/offline_inference/tpu.py" \
|
||||||
|
|
||||||
|
|
||||||
# TODO: This test fails because it uses RANDOM_SEED sampling
|
# TODO: This test fails because it uses RANDOM_SEED sampling
|
||||||
|
@ -45,6 +45,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
|
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
|
||||||
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
|
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
|
||||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||||
|
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
||||||
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
|
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
|
||||||
VLLM_USE_RAY_SPMD_WORKER: bool = False
|
VLLM_USE_RAY_SPMD_WORKER: bool = False
|
||||||
VLLM_USE_RAY_COMPILED_DAG: bool = False
|
VLLM_USE_RAY_COMPILED_DAG: bool = False
|
||||||
@ -446,6 +447,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_XLA_CACHE_PATH",
|
"VLLM_XLA_CACHE_PATH",
|
||||||
os.path.join(get_default_cache_root(), "vllm", "xla_cache"),
|
os.path.join(get_default_cache_root(), "vllm", "xla_cache"),
|
||||||
)),
|
)),
|
||||||
|
|
||||||
|
# If set, assert on XLA recompilation after each execution step.
|
||||||
|
"VLLM_XLA_CHECK_RECOMPILATION":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
|
||||||
"VLLM_FUSED_MOE_CHUNK_SIZE":
|
"VLLM_FUSED_MOE_CHUNK_SIZE":
|
||||||
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")),
|
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")),
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ import torch.nn as nn
|
|||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
import torch_xla.runtime as xr
|
import torch_xla.runtime as xr
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
from vllm.attention.backends.abstract import AttentionType
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@ -73,6 +74,10 @@ class TPUModelRunner:
|
|||||||
scheduler_config = self.scheduler_config
|
scheduler_config = self.scheduler_config
|
||||||
parallel_config = self.parallel_config
|
parallel_config = self.parallel_config
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
|
||||||
|
if self.check_recompilation:
|
||||||
|
self.num_xla_graphs = xr.get_num_cached_compilation_graph()
|
||||||
|
self.enforce_eager = model_config.enforce_eager
|
||||||
self.pin_memory = is_pin_memory_available()
|
self.pin_memory = is_pin_memory_available()
|
||||||
self.dtype = self.model_config.dtype
|
self.dtype = self.model_config.dtype
|
||||||
|
|
||||||
@ -671,6 +676,12 @@ class TPUModelRunner:
|
|||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||||
)
|
)
|
||||||
|
# Check there is no new graph compilation, all the graphs should be
|
||||||
|
# captured and compiled during warming up.
|
||||||
|
if self.check_recompilation and not self.enforce_eager:
|
||||||
|
curr_cached_graph = xr.get_num_cached_compilation_graph()
|
||||||
|
assert self.num_xla_graphs == curr_cached_graph, (
|
||||||
|
"Recompilation after warm up is detected.")
|
||||||
return model_runner_output
|
return model_runner_output
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
@ -810,6 +821,14 @@ class TPUModelRunner:
|
|||||||
xm.wait_device_ops()
|
xm.wait_device_ops()
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||||
|
# Record the number cached XLA graph after warming up, this will be
|
||||||
|
# used for checking there is no additional graph compilation during
|
||||||
|
# runtime execution.
|
||||||
|
if self.check_recompilation:
|
||||||
|
total_cached_graphs = xr.get_num_cached_compilation_graph()
|
||||||
|
num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
|
||||||
|
logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
|
||||||
|
self.num_xla_graphs += num_compiled_graphs
|
||||||
|
|
||||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user