[V1] TPU - Fix the chunked prompt bug (#15713)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Author: Alexander Matveev
Date: 2025-03-28 16:19:04 -04:00 (committed by GitHub)
Parent: 04437e313d
Commit: c3f687ac22
2 changed files with 17 additions and 1 deletion


@@ -48,7 +48,10 @@ def test_models(
     with vllm_runner(
             model,
-            max_model_len=8192,
+            # Note: max_num_batched_tokens == 1024 is needed here to
+            # actually test chunked prompt
+            max_num_batched_tokens=1024,
+            max_model_len=8196,
             gpu_memory_utilization=0.7,
             max_num_seqs=16,
             tensor_parallel_size=tensor_parallel_size) as vllm_model:
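
Note on the test change: the chunked-prompt path only triggers when a prompt exceeds the per-step token budget, so the test pins max_num_batched_tokens well below the prompt length instead of relying on the default. A minimal standalone sketch of the same setup through vLLM's offline LLM entrypoint (model name, prompt, and exact values are illustrative placeholders, not taken from the test):

    from vllm import LLM, SamplingParams

    # A 1024-token budget per scheduler step forces any longer prompt to be
    # prefilled in chunks across several steps.
    llm = LLM(
        model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
        max_num_batched_tokens=1024,
        max_model_len=8192,
        max_num_seqs=16,
        gpu_memory_utilization=0.7,
    )

    long_prompt = "hello " * 2000  # well over 1024 tokens -> chunked prefill
    outputs = llm.generate([long_prompt], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)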


@@ -618,6 +618,7 @@ class TPUModelRunner:
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
+        discard_sampled_tokens_req_indices = []
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
             req_state = self.requests[req_id]
@@ -633,6 +634,10 @@ class TPUModelRunner:
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)
 
+                # Record the index of the request that should not be sampled,
+                # so that we could clear the sampled tokens before returning.
+                discard_sampled_tokens_req_indices.append(i)
+
         assert all(
             req_id is not None for req_id in
             self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
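
The bookkeeping added above only concerns requests whose prompt is still mid-prefill at the end of the step: the token sampled for such a request is a by-product of the chunked forward pass and must not be surfaced. A simplified, self-contained sketch of that decision (the tuple fields and numbers are illustrative, not the runner's actual state objects):

    # Toy per-request state: (num_computed_tokens, num_scheduled_tokens,
    # total_known_tokens).  A request whose scheduled chunk does not reach
    # its known length is still prefilling and should yield no token.
    requests = [
        (0,    1024, 3000),  # first prefill chunk  -> discard this step
        (2048, 952,  3000),  # final prefill chunk  -> real sample, keep
        (3000, 1,    3001),  # decode step          -> real sample, keep
    ]

    discard_sampled_tokens_req_indices = []
    for i, (computed, scheduled, num_tokens) in enumerate(requests):
        seq_len = computed + scheduled
        if seq_len < num_tokens:
            discard_sampled_tokens_req_indices.append(i)

    print(discard_sampled_tokens_req_indices)  # [0]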
@@ -646,11 +651,19 @@ class TPUModelRunner:
         if max_gen_len == 1:
             valid_sampled_token_ids = selected_token_ids.tolist()
+
+            # Mask out the sampled tokens that should not be sampled.
+            # TODO: Keep in sync with gpu_model_runner.py, in particular
+            # the "else" case here
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[i].clear()
+
+            # Append sampled tokens
             for i, req_state, seq_len in request_seq_lens:
                 token_id = valid_sampled_token_ids[i][0]
                 self.input_batch.token_ids_cpu[i, seq_len] = token_id
                 req_state.output_token_ids.append(token_id)
                 self.input_batch.num_tokens[i] += 1
         else:
             valid_mask = selected_token_ids != INVALID_TOKEN_ID
             gen_lens = valid_mask.sum(dim=1).tolist()
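
Putting the pieces together, the fix itself is the clearing loop: before this change nothing emptied the rows belonging to partially prefilled (chunked) requests, so a spurious token could be returned for a prompt that was not finished yet. A toy sketch of what the added masking does to the per-request token lists (plain Python, made-up values):

    selected_token_ids = [[17], [934], [5]]   # one sampled token per request
    discard_sampled_tokens_req_indices = [0]  # request 0 is still mid-prefill

    valid_sampled_token_ids = [row[:] for row in selected_token_ids]
    for i in discard_sampled_tokens_req_indices:
        valid_sampled_token_ids[i].clear()

    # Request 0 now yields nothing this step; the others keep their token.
    print(valid_sampled_token_ids)  # [[], [934], [5]]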