diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py
index 591aa9c5..0d7e8d8d 100644
--- a/tests/v1/tpu/test_basic.py
+++ b/tests/v1/tpu/test_basic.py
@@ -48,7 +48,10 @@ def test_models(
 
     with vllm_runner(
             model,
-            max_model_len=8192,
+            # Note: max_num_batched_tokens == 1024 is needed here to
+            # actually test chunked prompt
+            max_num_batched_tokens=1024,
+            max_model_len=8196,
             gpu_memory_utilization=0.7,
             max_num_seqs=16,
             tensor_parallel_size=tensor_parallel_size) as vllm_model:
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 5401fff2..695e31f7 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -618,6 +618,7 @@ class TPUModelRunner:
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
+        discard_sampled_tokens_req_indices = []
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
             req_state = self.requests[req_id]
@@ -633,6 +634,10 @@
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)
 
+                # Record the index of the request that should not be sampled,
+                # so that we could clear the sampled tokens before returning.
+                discard_sampled_tokens_req_indices.append(i)
+
         assert all(
             req_id is not None for req_id in
             self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
@@ -646,11 +651,19 @@
 
         if max_gen_len == 1:
             valid_sampled_token_ids = selected_token_ids.tolist()
+            # Mask out the sampled tokens that should not be sampled.
+            # TODO: Keep in sync with gpu_model_runner.py, in particular
+            # the "else" case here
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[i].clear()
+
+            # Append sampled tokens
             for i, req_state, seq_len in request_seq_lens:
                 token_id = valid_sampled_token_ids[i][0]
                 self.input_batch.token_ids_cpu[i, seq_len] = token_id
                 req_state.output_token_ids.append(token_id)
                 self.input_batch.num_tokens[i] += 1
+
         else:
             valid_mask = selected_token_ids != INVALID_TOKEN_ID
             gen_lens = valid_mask.sum(dim=1).tolist()
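A minimal, self-contained sketch (not part of the patch) of the discard-and-mask flow the diff adds: indices of partially prefilled requests are collected first, then their entries in the per-request token lists are cleared before the result is returned. The names selected_token_ids and discard_sampled_tokens_req_indices mirror the diff; needs_discard and mask_discarded_tokens are hypothetical stand-ins for the runner's seq_len-based partial-request check.

import torch

def mask_discarded_tokens(selected_token_ids: torch.Tensor,
                          needs_discard: list[bool]) -> list[list[int]]:
    # Collect the indices of requests whose sampled token must be dropped
    # (in the runner this is decided while iterating over the batch).
    discard_sampled_tokens_req_indices = [
        i for i, discard in enumerate(needs_discard) if discard
    ]
    # max_gen_len == 1 path: one candidate token per request.
    valid_sampled_token_ids = selected_token_ids.tolist()
    # Clear the sampled tokens that should not be kept, as in the diff above.
    for i in discard_sampled_tokens_req_indices:
        valid_sampled_token_ids[i].clear()
    return valid_sampled_token_ids

# Requests 0 and 2 keep their token; request 1 is still prefilling.
print(mask_discarded_tokens(torch.tensor([[11], [22], [33]]),
                            [False, True, False]))
# -> [[11], [], [33]]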