[TPU][V1] Fix padding recompilation when max-num-batched-tokens is not even (#16726)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-04-17 20:09:57 +02:00 committed by GitHub
parent 5125d72f02
commit 5989f4684d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 7 deletions

View File

@ -294,11 +294,19 @@ def test_update_states_request_unscheduled(model_runner):
def test_get_paddings():
# Bucketed padding
min_token_size, max_token_size, padding_gap = 16, 512, 64
expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
actual_paddings = _get_token_paddings(min_token_size, max_token_size,
padding_gap)
# Bucketed padding with max_token_size not a power of two.
max_token_size = 317
expected_paddings = [16, 32, 64, 128, 192, 256, 320]
actual_paddings = _get_token_paddings(min_token_size, max_token_size,
padding_gap)
assert actual_paddings == expected_paddings
# Exponential padding.
max_token_size, padding_gap = 1024, 0
expected_paddings = [16, 32, 64, 128, 256, 512, 1024]

View File

@ -128,10 +128,16 @@ class TPUModelRunner:
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
# InputBatch needs to work with sampling tensors greater than padding
# to avoid dynamic shapes. Also, avoid suboptimal alignment.
self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
self.num_tokens_paddings = _get_token_paddings(
min_token_size=16,
max_token_size=scheduler_config.max_num_batched_tokens,
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
# In case `max_num_tokens < max(num_tokens_paddings)` use the actual
# padded max value to pre-allocate data structures and pre-compile.
self.max_num_tokens = self.num_tokens_paddings[-1]
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
@ -211,10 +217,6 @@ class TPUModelRunner:
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
self.num_tokens_paddings = _get_token_paddings(
min_token_size=16,
max_token_size=self.max_num_tokens,
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
self.num_reqs_paddings = _get_req_paddings(
min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)

View File

@ -156,8 +156,8 @@ class TPUWorker:
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)
self.model_runner._dummy_run(
self.scheduler_config.max_num_batched_tokens)
# `max_num_tokens >= max_num_batched_tokens` due to padding.
self.model_runner._dummy_run(self.model_runner.max_num_tokens)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()