diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 6300f16c..c61c449e 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -53,6 +53,41 @@ INVALID_TOKEN_ID = -1
 MIN_NUM_SEQS = 8
 
 
+#########################################################
+# Ways to avoid recompilation
+#########################################################
+#
+# The model executor has two primary components:
+# 1. preparing the model and sampler inputs
+# 2. executing the model and sampler.
+# The core idea is to avoid any TPU computation during input preparation.
+# For better compilation tracking and increased flexibility, the model
+# execution and sampler are divided into several distinct subgraphs.
+#
+# Below are the detailed steps:
+#
+# Step 1
+# Avoid TPU operations when preparing the model and sampler inputs.
+# Prepare the tensors on the CPU and transfer them to the XLA device with
+# cpu_tensor.to(xla_device), which only triggers a CPU-to-TPU copy and
+# does not cause compilation.
+#
+# Step 2
+# Decompose the TPU execution into subgraphs (4 at the moment):
+# 1. the main model
+# 2. selecting the hidden states for each request
+# 3. the sampler
+# 4. the encoder.
+# Wrap each subgraph in torch.compile. This guarantees that dummy_run
+# and execute_model trace the same subgraph topology.
+# The results of a subgraph should either be fed to another subgraph or
+# transferred from TPU to CPU with xla_tensor.cpu() for subsequent
+# processing on the CPU.
+#
+# Step 3
+# The dummy_run should be comprehensive, covering every potential input
+# shape and every branch condition so that all subgraph variants are
+# pre-compiled.
 class TPUModelRunner:
 
     def __init__(
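
Below is a minimal, self-contained sketch (not part of the patch above) of the pattern the new comment describes, assuming torch_xla is installed and its "openxla" torch.compile backend is available. The function names, tensor shapes, and padding buckets are hypothetical and do not mirror the real TPUModelRunner; they only illustrate the three steps.

# Illustrative only -- names, shapes, and buckets are hypothetical.
import torch
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()


# Step 2: each TPU subgraph is wrapped in torch.compile so that dummy_run
# and execute_model trace the same graph topology.
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def main_model(input_ids: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real forward pass; returns fake "hidden states".
    return input_ids.unsqueeze(-1).float() + positions.unsqueeze(-1).float()


@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def select_hidden_states(hidden: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Pick one hidden state per request.
    return hidden[indices]


def execute(num_tokens: int, num_reqs: int) -> torch.Tensor:
    # Step 1: build all inputs on the CPU, then transfer; .to() is a plain
    # host-to-device copy and never triggers XLA compilation.
    input_ids_cpu = torch.zeros(num_tokens, dtype=torch.int32)
    positions_cpu = torch.arange(num_tokens, dtype=torch.int32)
    logits_indices_cpu = torch.arange(num_reqs, dtype=torch.int64)

    hidden = main_model(input_ids_cpu.to(xla_device),
                        positions_cpu.to(xla_device))
    selected = select_hidden_states(hidden, logits_indices_cpu.to(xla_device))
    # Results either feed another subgraph or come back via .cpu().
    return selected.cpu()


def dummy_run() -> None:
    # Step 3: warm up every padded (num_tokens, num_reqs) bucket the runner
    # can see at serving time so no compilation happens on the hot path.
    for num_tokens in (16, 32, 64):
        for num_reqs in (8, 16):
            execute(num_tokens, num_reqs)

With fullgraph=True and dynamic=False, every new input shape produces a new specialized graph rather than a silent graph break, which is why a dummy_run along these lines has to enumerate every padded bucket up front.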