[V1] TPU - Fix the chunked prompt bug (#15713)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
parent 04437e313d
commit c3f687ac22
@@ -48,7 +48,10 @@ def test_models(
     with vllm_runner(
             model,
-            max_model_len=8192,
+            # Note: max_num_batched_tokens == 1024 is needed here to
+            # actually test chunked prompt
+            max_num_batched_tokens=1024,
+            max_model_len=8196,
             gpu_memory_utilization=0.7,
             max_num_seqs=16,
             tensor_parallel_size=tensor_parallel_size) as vllm_model:
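For context on why the test now caps max_num_batched_tokens at 1024: once the per-step token budget is smaller than the prompt, a long prompt must be prefilled across several scheduler steps, which is what actually exercises the chunked-prompt path. Below is a minimal sketch of that split; chunk_prompt is a hypothetical helper for illustration, not a vLLM API (the scheduler does the real chunking internally).

# Hypothetical helper illustrating chunked prefill.
def chunk_prompt(prompt_len: int, max_num_batched_tokens: int) -> list[int]:
    """Token counts processed in each prefill step."""
    chunks = []
    remaining = prompt_len
    while remaining > 0:
        step = min(remaining, max_num_batched_tokens)
        chunks.append(step)
        remaining -= step
    return chunks

# A 3000-token prompt with a 1024-token budget needs three prefill steps:
assert chunk_prompt(3000, 1024) == [1024, 1024, 952]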
@@ -618,6 +618,7 @@ class TPUModelRunner:
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
+        discard_sampled_tokens_req_indices = []
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
             req_state = self.requests[req_id]
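The new discard_sampled_tokens_req_indices list is the heart of the fix: a request whose prompt is still mid-chunk produces a "sampled" token that must not be kept. A simplified, self-contained sketch of that bookkeeping follows; the Req class and its fields are illustrative stand-ins, not the exact runner state.

from dataclasses import dataclass

@dataclass
class Req:  # illustrative stand-in for vLLM's CachedRequestState
    num_computed_tokens: int
    num_prompt_tokens: int

# One request mid-prefill (index 1), two past their prompts.
batch_requests = [Req(10, 10), Req(512, 3000), Req(7, 7)]

request_seq_lens = []                    # (batch_index, request, seq_len)
discard_sampled_tokens_req_indices = []  # batch indices whose sample is dropped

for i, req in enumerate(batch_requests):
    if req.num_computed_tokens < req.num_prompt_tokens:
        # Prompt still being prefilled in chunks: this step's sampled
        # token is meaningless for the request and must be discarded.
        discard_sampled_tokens_req_indices.append(i)
    else:
        request_seq_lens.append((i, req, req.num_computed_tokens))

assert discard_sampled_tokens_req_indices == [1]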
@@ -633,6 +634,10 @@ class TPUModelRunner:
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)
 
+                # Record the index of the request that should not be sampled,
+                # so that we could clear the sampled tokens before returning.
+                discard_sampled_tokens_req_indices.append(i)
+
         assert all(
             req_id is not None for req_id in
             self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
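The surrounding generator.set_offset(generator.get_offset() - 4) call rewinds the per-request torch.Generator when a sample is discarded, so the same random state is replayed when the token is sampled for real later; as the in-code comment warns, this leans on CUDA-specific torch internals. A rough sketch of the idea, assuming a CUDA device is available (the diff's "- 4" is the offset one sampling call consumes in the runner):

import torch

# Sketch: rewinding a CUDA generator so a discarded draw is replayed.
gen = torch.Generator(device="cuda")
gen.manual_seed(42)

offset_before = gen.get_offset()
discarded = torch.rand(1, device="cuda", generator=gen)  # consumes randomness

gen.set_offset(offset_before)  # rewind as if the draw never happened
replayed = torch.rand(1, device="cuda", generator=gen)

assert torch.equal(discarded, replayed)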
@@ -646,11 +651,19 @@ class TPUModelRunner:
         if max_gen_len == 1:
             valid_sampled_token_ids = selected_token_ids.tolist()
 
+            # Mask out the sampled tokens that should not be sampled.
+            # TODO: Keep in sync with gpu_model_runner.py, in particular
+            # the "else" case here
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[i].clear()
+
+            # Append sampled tokens
             for i, req_state, seq_len in request_seq_lens:
                 token_id = valid_sampled_token_ids[i][0]
                 self.input_batch.token_ids_cpu[i, seq_len] = token_id
                 req_state.output_token_ids.append(token_id)
                 self.input_batch.num_tokens[i] += 1
+
         else:
             valid_mask = selected_token_ids != INVALID_TOKEN_ID
             gen_lens = valid_mask.sum(dim=1).tolist()
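The effect of the new masking loop in the max_gen_len == 1 path, in isolation: the row for each mid-prefill request is emptied, so the code that follows neither returns nor appends a token for it. A toy example with made-up token values:

# Toy values; in the runner, selected_token_ids.tolist() produces the rows.
valid_sampled_token_ids = [[17], [23], [9]]   # one sampled token per request
discard_sampled_tokens_req_indices = [1]      # request 1 is mid-prefill

for i in discard_sampled_tokens_req_indices:
    valid_sampled_token_ids[i].clear()

assert valid_sampled_token_ids == [[17], [], [9]]
# Request 1 now contributes no token; requests 0 and 2 are unaffected.

Per the TODO in the diff, this masking must be kept in sync with gpu_model_runner.py, in particular for the else branch, where multi-token outputs are filtered via valid_mask and INVALID_TOKEN_ID instead.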