[TPU][V1] Refine tpu_model_runner to mitigate future recompilation issues (#16275)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Authored by Chengji Yao on 2025-04-09 17:51:51 -07:00, committed by GitHub
parent 1bff42c4b7
commit a454748544
4 changed files with 166 additions and 125 deletions

View File

@ -44,7 +44,7 @@ def test_tpu_compilation():
assert generated_text.startswith(answer)
compiled_codes = sorted(
glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py")))
for i, compiled_code in enumerate(compiled_codes):
print("{} file: {}".format(i + 1, compiled_code))
@ -52,15 +52,21 @@ def test_tpu_compilation():
# We should only trigger Dynamo compilation 2 times:
# 1. Forward pass without kv_caches
# 2. Forward pass with kv_caches
# Check we have 4 compiled codes
# Check we have 2 compiled codes
assert len(compiled_codes) == 2
kv_cache_prefix = "kv_cache"
attn_prefix = "ragged_paged_attention"
def extract_compiled_index(s):
parts = s.replace(".", "_").split("_")
numbers = [int(part) for part in parts if part.isdigit()]
return numbers[0]
# Check all the compilations are as expected
compiled_fns = sorted(
glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
compiled_fns = sorted(glob.glob(
os.path.join(temp_dir, "__compiled_fn*Captured*.py")),
key=lambda s: extract_compiled_index(s))
for i, compiled_fn in enumerate(compiled_fns):
print("{} file: {}".format(i + 1, compiled_fn))

View File

@ -7,9 +7,9 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
_get_padded_token_len,
_get_paddings)
from vllm.v1.worker.tpu_model_runner import (
TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
_get_padded_token_len, _get_req_paddings, _get_token_paddings)
# Mock torch_xla module since it may not be available in the test environments
torch_xla_patcher = mock.patch.dict(
@ -296,16 +296,29 @@ def test_update_states_request_unscheduled(model_runner):
def test_get_paddings():
min_token_size, max_token_size, padding_gap = 16, 512, 64
expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
actual_paddings = _get_paddings(min_token_size, max_token_size,
padding_gap)
actual_paddings = _get_token_paddings(min_token_size, max_token_size,
padding_gap)
assert actual_paddings == expected_paddings
def test_get_padded_token_len():
min_token_size, max_token_size, padding_gap = 16, 512, 64
paddings = _get_paddings(min_token_size, max_token_size, padding_gap)
paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
assert _get_padded_token_len(paddings, 1) == 16
assert _get_padded_token_len(paddings, 16) == 16
assert _get_padded_token_len(paddings, 20) == 32
assert _get_padded_token_len(paddings, 300) == 320
assert _get_padded_token_len(paddings, 512) == 512
def test_get_padded_num_reqs_with_upper_limit():
assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8
assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16
assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32
assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28
def test_get_req_paddings():
assert _get_req_paddings(1, 32) == [8, 16, 32]
assert _get_req_paddings(8, 32) == [8, 16, 32]
assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
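
For reference, a standalone sketch (not the vLLM implementation) of the incremental token bucketing these tests exercise: sizes double up to padding_gap, then grow by padding_gap until max_token_size is covered.

# Sketch of the incremental token-padding schedule for (16, 512, 64);
# the gap == 0 (purely exponential) branch is omitted here.
def token_paddings_sketch(min_size, max_size, gap):
    paddings, num = [], min_size
    while num <= gap:           # exponential phase
        paddings.append(num)
        num *= 2
    num = paddings[-1]
    while num < max_size:       # incremental phase
        num += gap
        paddings.append(num)
    return paddings

assert token_paddings_sketch(16, 512, 64) == [
    16, 32, 64, 128, 192, 256, 320, 384, 448, 512]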

View File

@ -3,7 +3,6 @@ from dataclasses import dataclass, field
from typing import Optional
import torch
import torch_xla.core.xla_model as xm
from vllm.v1.worker.gpu_input_batch import InputBatch
@ -24,15 +23,15 @@ class TPUSupportedSamplingMetadata:
# This class exposes a more xla-friendly interface than SamplingMetadata
# on TPU, in particular all arguments should be traceable and no optionals
# are allowed, to avoid graph recompilation on Nones.
temperature: torch.Tensor
temperature: torch.Tensor = None
min_p: torch.Tensor
min_p: torch.Tensor = None
# Still too slow on forward_native!
top_k: torch.Tensor = None
top_p: torch.Tensor = None
# Greedy sampling flag for compiling single xla graph.
all_greedy: torch.Tensor = None
all_greedy: bool = True
# Generator not supported by xla
generators: dict[int,
@ -57,64 +56,58 @@ class TPUSupportedSamplingMetadata:
allowed_token_ids_mask = None
bad_words_token_ids = None
indices_do_sample: torch.Tensor = None
@classmethod
def from_input_batch(
cls, input_batch: InputBatch,
indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
cls,
input_batch: InputBatch,
padded_num_reqs: int,
xla_device: torch.device,
generate_params_if_all_greedy: bool = False
) -> "TPUSupportedSamplingMetadata":
"""
Copy sampling tensor slices from `input_batch` to on-device tensors.
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
also reuses the on-device persistent tensors managed in `input_batch`
to reduce waste.
ops to CPU and produces tensors of fixed `padded_num_reqs` size.
`indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape.
We expect sampling params tensors to be padded to the same fixed shape.
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
Args:
input_batch: The input batch containing sampling parameters.
padded_num_reqs: The padded number of requests.
xla_device: The XLA device.
generate_params_if_all_greedy: If True, generate sampling parameters
even if all requests are greedy. This is useful for cases where
we want to pre-compile a graph with sampling parameters, even if
they are not strictly needed for greedy decoding.
"""
num_reqs = input_batch.num_reqs
padded_num_reqs = len(indices_do_sample)
# Early return to avoid an unnecessary CPU-to-TPU copy
if (input_batch.all_greedy is True
and generate_params_if_all_greedy is False):
return cls(all_greedy=True)
def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
fill_val) -> torch.Tensor:
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
num_reqs = input_batch.num_reqs
def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
# Pad value is the default one.
cpu_tensor[num_reqs:padded_num_reqs] = fill_val
# Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
# consistent. We can't have flags to skip copies or we'll end up
# recompiling.
copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
fill_slice(input_batch.temperature_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["temperature"])
# TODO Temporarily disabled until sampling options are enabled
# copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
# copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
# fill_slice(input_batch.top_p_cpu_tensor)
# fill_slice(input_batch.top_k_cpu_tensor)
fill_slice(input_batch.min_p_cpu_tensor,
DEFAULT_SAMPLING_PARAMS["min_p"])
xm.mark_step()
xm.wait_device_ops()
# Slice persistent device tensors to a fixed pre-compiled padded shape.
return cls(
temperature=input_batch.temperature[:padded_num_reqs],
# Scalar tensor for xla-friendly tracing.
all_greedy=torch.tensor(input_batch.all_greedy,
dtype=torch.bool,
device=input_batch.device),
temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
to(xla_device),
all_greedy=input_batch.all_greedy,
# TODO enable more and avoid returning None values
top_p=None, # input_batch.top_p[:padded_num_reqs],
top_k=None, # input_batch.top_k[:padded_num_reqs],
min_p=input_batch.min_p[:padded_num_reqs],
generators=input_batch.generators,
indices_do_sample=indices_do_sample)
min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
xla_device),
generators=input_batch.generators)
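
As a toy illustration of the padding pattern above (plain PyTorch, made-up values): the persistent CPU buffer is filled with the default value beyond num_reqs, and only the fixed-size [:padded_num_reqs] slice is moved to the XLA device, so the traced shape never changes.

import torch

# Hypothetical values; the real buffers live in InputBatch.
num_reqs, padded_num_reqs = 3, 4
temperature_cpu = torch.zeros(8)                 # persistent CPU buffer
temperature_cpu[:num_reqs] = torch.tensor([0.7, 0.2, 0.9])
temperature_cpu[num_reqs:padded_num_reqs] = 0.0  # default fill value
fixed_slice = temperature_cpu[:padded_num_reqs]  # always shape [padded_num_reqs]
# fixed_slice.to(xla_device) would then trace with a static shape of 4.
print(fixed_slice)  # tensor([0.7000, 0.2000, 0.9000, 0.0000])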

View File

@ -32,7 +32,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput, SamplerOutput)
ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
@ -177,10 +177,12 @@ class TPUModelRunner:
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
self.num_tokens_paddings = _get_paddings(
self.num_tokens_paddings = _get_token_paddings(
min_token_size=16,
max_token_size=self.max_num_tokens,
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
self.num_reqs_paddings = _get_req_paddings(
min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
@ -508,7 +510,7 @@ class TPUModelRunner:
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
logits_indices = logits_indices.to(self.device)
return attn_metadata, logits_indices
return attn_metadata, logits_indices, padded_num_reqs
def _scatter_placeholders(
self,
@ -663,7 +665,8 @@ class TPUModelRunner:
mm_embeds = []
# Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
scheduler_output)
if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
@ -682,11 +685,6 @@ class TPUModelRunner:
input_ids = self.input_ids
inputs_embeds = None
num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
# are copied to device in chunks of pre-compiled padded shape to
# avoid recompilations.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, logits_indices)
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
@ -694,6 +692,10 @@ class TPUModelRunner:
positions=self.position_ids,
inputs_embeds=inputs_embeds,
)
hidden_states = self.select_hidden_states(hidden_states,
logits_indices)
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, padded_num_reqs, self.device)
selected_token_ids = self.sample_from_hidden(hidden_states,
tpu_sampling_metadata)
# Remove padding on cpu and keep dynamic op outside of xla graph.
@ -857,60 +859,78 @@ class TPUModelRunner:
inputs_embeds=inputs_embeds)
self._hidden_states_dtype = out.dtype
def capture_model(self) -> None:
"""Compile the model."""
def _precompile_backbone(self) -> None:
logger.info("Compiling the model with different input shapes.")
start = time.perf_counter()
for num_tokens in self.num_tokens_paddings:
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(num_tokens)
xm.mark_step()
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("model")
self._update_num_xla_graphs("model backbone")
def _precompile_select_hidden_states(self) -> None:
# Compile hidden state selection function for bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
logger.info(
"Compiling select_hidden_states with different input shapes.")
start = time.perf_counter()
hsize = self.model_config.get_hidden_size()
for num_tokens in self.num_tokens_paddings:
dummy_hidden = torch.zeros((num_tokens, hsize),
device=self.device,
dtype=self._hidden_states_dtype)
torch._dynamo.mark_dynamic(dummy_hidden, 0)
for num_reqs in self.num_reqs_paddings:
indices = torch.zeros(num_reqs,
dtype=torch.int32,
device=self.device)
torch._dynamo.mark_dynamic(indices, 0)
self.select_hidden_states(dummy_hidden, indices)
logger.info(" -- num_tokens: %d", num_tokens)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("select_hidden_states")
def _precompile_sample_from_hidden(self) -> None:
logger.info("Compiling sampling with different input shapes.")
start = time.perf_counter()
hsize = self.model_config.get_hidden_size()
device = self.device
# Compile sampling step for different model+sampler outputs in bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
for num_tokens in self.num_tokens_paddings:
num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
for num_reqs in self.num_reqs_paddings:
dummy_hidden = torch.zeros((num_reqs, hsize),
device=self.device,
dtype=self._hidden_states_dtype)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while True:
indices = torch.zeros(
num_reqs_to_sample,
dtype=torch.int32,
device=device,
)
xm.mark_step()
sampling_meta = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, indices)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample)
out = self.sample_from_hidden(dummy_hidden, sampling_meta)
out = out.cpu()
# There can't be more requests than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
break
# Make sure to compile the `max_num_reqs` upper-limit case
num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
num_reqs_to_sample + 1, self.max_num_reqs)
# The first dimension of dummy_hidden cannot be marked dynamic because
# some operations in the sampler require it to be static.
for all_greedy in [False, True]:
generate_params_if_all_greedy = not all_greedy
sampling_metadata = (
TPUSupportedSamplingMetadata.from_input_batch(
self.input_batch,
num_reqs,
self.device,
generate_params_if_all_greedy,
))
sampling_metadata.all_greedy = all_greedy
self.sample_from_hidden(dummy_hidden, sampling_metadata)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("sampling")
def capture_model(self) -> None:
"""
Precompile all the subgraphs with possible input shapes.
"""
# TODO: precompile encoder
self._precompile_backbone()
self._precompile_select_hidden_states()
self._precompile_sample_from_hidden()
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@ -962,48 +982,55 @@ class TPUModelRunner:
compiled_model.original_code_object)
compiled_model.compiled_codes.clear()
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def select_hidden_states(self, hidden_states, indices_do_sample):
return hidden_states[indices_do_sample]
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_hidden(
self,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata,
) -> torch.Tensor:
"""
Sample with xla-friendly function. This function is to be traced
separately for lighter compilation overhead.
"""
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample]
# SamplingMetadata here for pruning output in LogitsProcessor, disabled.
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
"""
logits = self.model.compute_logits(sample_hidden_states, None)
def sample(
logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata
) -> SamplerOutput:
sampler_out = self.sampler(logits, sampling_metadata)
return sampler_out
# Optimized greedy sampling branch, tracing both paths in a single pass
# NOTE all_greedy is a scalar, this is just an optimized if/else.
out_tokens = torch.where(
sampling_metadata.all_greedy,
torch.argmax(logits, dim=-1, keepdim=True),
sample(logits, sampling_metadata).sampled_token_ids)
if sampling_metadata.all_greedy:
out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
else:
out_tokens = self.sampler(logits,
sampling_metadata).sampled_token_ids
return out_tokens
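
A toy example of the greedy branch (plain PyTorch, made-up logits): when all_greedy is set, sampling collapses to an argmax over the vocabulary dimension.

import torch

logits = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.0, 0.5]])
greedy_tokens = torch.argmax(logits, dim=-1, keepdim=True)
print(greedy_tokens)  # tensor([[1], [0]])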
def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs)
def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple
def get_input_embeddings(self, *args, **kwargs):
return self.model.get_input_embeddings(*args, **kwargs)
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
logger.info("Preparing request paddings:")
# assert min_req_size is a power of 2
assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
paddings: list = []
num = max(MIN_NUM_SEQS, min_req_size)
while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num):
paddings.append(num)
logger.info(" %d", num)
num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size)
return paddings
def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
return min(res, upper_limit)
def _get_paddings(min_token_size: int, max_token_size: int,
padding_gap: int) -> list[int]:
def _get_token_paddings(min_token_size: int, max_token_size: int,
padding_gap: int) -> list[int]:
"""Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size
@ -1013,18 +1040,20 @@ def _get_paddings(min_token_size: int, max_token_size: int,
first double the size up to padding_gap,
then increase the padding size by padding_gap.
"""
# assert min_token_size is a power of 2
assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
paddings = []
num = min_token_size
if padding_gap == 0:
logger.info("Using exponential paddings:")
logger.info("Using exponential token paddings:")
while num <= max_token_size:
logger.info(" %d", num)
paddings.append(num)
num *= 2
else:
logger.info("Using incremental paddings:")
logger.info("Using incremental token paddings:")
while num <= padding_gap:
logger.info(" %d", num)
paddings.append(num)