From 5989f4684d62d5cb1852624ce0fd04fc08dd239b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?=
Date: Thu, 17 Apr 2025 20:09:57 +0200
Subject: [PATCH] [TPU][V1] Fix padding recompilation when
 `max-num-batched-tokens` is not even (#16726)

Signed-off-by: NickLucche
---
 tests/v1/tpu/worker/test_tpu_model_runner.py |  8 ++++++++
 vllm/v1/worker/tpu_model_runner.py           | 12 +++++++-----
 vllm/v1/worker/tpu_worker.py                 |  4 ++--
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
index 5c7eab0b..5db60609 100644
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -294,11 +294,19 @@ def test_update_states_request_unscheduled(model_runner):
 
 
 def test_get_paddings():
+    # Bucketed padding
     min_token_size, max_token_size, padding_gap = 16, 512, 64
     expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
+
+    # Bucketed padding with max_token_size not a power of two.
+    max_token_size = 317
+    expected_paddings = [16, 32, 64, 128, 192, 256, 320]
     actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                           padding_gap)
     assert actual_paddings == expected_paddings
 
+    # Exponential padding.
     max_token_size, padding_gap = 1024, 0
     expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index c61c449e..b66cd8d2 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -128,10 +128,16 @@ class TPUModelRunner:
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens
         # InputBatch needs to work with sampling tensors greater than padding
         # to avoid dynamic shapes. Also, avoid suboptimal alignment.
         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
+        self.num_tokens_paddings = _get_token_paddings(
+            min_token_size=16,
+            max_token_size=scheduler_config.max_num_batched_tokens,
+            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+        # In case `max_num_tokens < max(num_tokens_paddings)` use the actual
+        # padded max value to pre-allocate data structures and pre-compile.
+        self.max_num_tokens = self.num_tokens_paddings[-1]
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -211,10 +217,6 @@
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
-        self.num_tokens_paddings = _get_token_paddings(
-            min_token_size=16,
-            max_token_size=self.max_num_tokens,
-            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_reqs_paddings = _get_req_paddings(
             min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
 
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 73c43969..8f2b4acc 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -156,8 +156,8 @@
             self.vllm_config.compilation_config.static_forward_context,
             runner_kv_caches)
 
-        self.model_runner._dummy_run(
-            self.scheduler_config.max_num_batched_tokens)
+        # `max_num_tokens >= max_num_batched_tokens` due to padding.
+        self.model_runner._dummy_run(self.model_runner.max_num_tokens)
 
         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()
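
Note for reviewers: the expected lists in test_get_paddings pin down the padding
schedule that _get_token_paddings must produce. The sketch below is reconstructed
from those test expectations alone, not copied from
vllm/v1/worker/tpu_model_runner.py, so the real helper may differ in details
(logging, asserts): with padding_gap == 0 the buckets double until they cover
max_token_size; with a positive gap they double up to padding_gap and then grow
linearly in padding_gap steps, so a max_token_size that is not a power of two
(e.g. 317) is still covered by the next bucket (320).

# Sketch of the padding schedule implied by test_get_paddings.
# Reconstructed from the test expectations, not from the actual helper.
def _get_token_paddings(min_token_size: int, max_token_size: int,
                        padding_gap: int) -> list[int]:
    paddings = []
    num = min_token_size
    if padding_gap == 0:
        # Exponential padding: double until max_token_size is covered.
        while True:
            paddings.append(num)
            if num >= max_token_size:
                break
            num *= 2
    else:
        # Bucketed padding: double up to padding_gap...
        while num <= padding_gap:
            paddings.append(num)
            num *= 2
        # ...then grow linearly in padding_gap steps until the last
        # bucket covers max_token_size.
        num //= 2
        while num < max_token_size:
            num += padding_gap
            paddings.append(num)
    return paddings

# The three cases exercised by the test:
assert _get_token_paddings(16, 512, 64) == [
    16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
assert _get_token_paddings(16, 317, 64) == [
    16, 32, 64, 128, 192, 256, 320]
assert _get_token_paddings(16, 1024, 0) == [16, 32, 64, 128, 256, 512, 1024]

This is also why the runner change sets self.max_num_tokens =
self.num_tokens_paddings[-1] and the worker warms up with
model_runner.max_num_tokens: the largest bucket can exceed
max_num_batched_tokens (320 > 317 above), so pre-allocated buffers and the
dummy compilation run must be sized to the padded maximum rather than the raw
scheduler limit, otherwise the first real batch landing in that bucket would
trigger a recompilation.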