[V1] Refactor num_computed_tokens logic (#15307)

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Cody Yu 2025-03-26 21:54:36 -07:00 committed by GitHub
parent fb22be5817
commit 54aa619459
5 changed files with 106 additions and 57 deletions

View File

@ -244,7 +244,9 @@ def test_schedule_partial_requests():
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
# Only the first request has a sampled token id because
# the remaining requests are still being prefilled.
sampled_token_ids=[[0], [], []],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@ -266,7 +268,7 @@ def test_schedule_partial_requests():
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
"""Test scheduling behavior with concurrent partial requests.
This test verifies that: there are multiple long prefill requests in the
@ -304,7 +306,7 @@ def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
sampled_token_ids=[[] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@ -325,6 +327,14 @@ def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
# Schedule the third step. All three requests are running.
# First and second requests are in the decode stage.
# All the remaining tokens in the third request are processed.
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule()
assert len(scheduler.running) == 3
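
The rewritten model_runner_output fixtures above rely on a convention this commit firms up: the model runner reports an empty token list for any request that is still being prefilled, and a single sampled token otherwise. A minimal sketch of how such a fixture could be built; the helper name is hypothetical, not part of the commit:

```python
# Hypothetical test helper (not in this commit): build sampled_token_ids under
# the new convention, where still-prefilling requests get an empty list.
def fake_sampled_token_ids(requests, finished_prefill_ids, token_id=0):
    return [[token_id] if req.request_id in finished_prefill_ids else []
            for req in requests]

# With three requests where only the first has finished its prefill, this
# yields [[0], [], []], matching the fixture in the first hunk above.
```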

View File

@ -231,8 +231,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
Test that the engine can handle multiple concurrent batches.
"""
def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:
def make_request_with_max_tokens(req_id: int,
max_tokens: int) -> EngineCoreRequest:
request = make_request()
request.request_id = req_id
request.sampling_params.max_tokens = max_tokens
return request
@ -279,6 +281,8 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Avoid all requests being scheduled at once.
enable_prefix_caching=False,
max_num_batched_tokens=10,
# Reduce startup time.
enforce_eager=True,
)
vllm_config = engine_args.create_engine_config()
engine_core = EngineCore(vllm_config=vllm_config,
@ -286,13 +290,13 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
executor_class=DummyExecutor)
assert engine_core.batch_queue is not None
# Add two requests in a row.
req = make_request_with_max_tokens(5)
engine_core.add_request(req)
req = make_request_with_max_tokens(5)
engine_core.add_request(req)
# Add two requests in a row. Each request has 12 prompt tokens.
req0 = make_request_with_max_tokens(0, 5)
engine_core.add_request(req0)
req1 = make_request_with_max_tokens(1, 5)
engine_core.add_request(req1)
# First saturate the batch queue.
# Schedule Batch 1: (10, req0)
assert engine_core.step_with_batch_queue() is None
assert engine_core.batch_queue.qsize() == 1
assert engine_core.step_with_batch_queue() is None
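
The numbers in the comment above follow from the configuration a few lines earlier: each prompt has 12 tokens but max_num_batched_tokens is 10, so the first step_with_batch_queue call can only schedule 10 tokens of req0. A small worked sketch of that budget arithmetic (variable names are illustrative):

```python
# Illustrative arithmetic behind "Schedule Batch 1: (10, req0)".
token_budget = 10   # max_num_batched_tokens
prompt_len = 12     # prompt tokens per request in this test

first_chunk = min(prompt_len, token_budget)  # 10 tokens of req0 scheduled
remaining = prompt_len - first_chunk         # 2 prompt tokens left for a later batch
assert (first_chunk, remaining) == (10, 2)
```
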

View File

@ -153,9 +153,9 @@ class Scheduler(SchedulerInterface):
num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens)
if self.scheduler_config.long_prefill_token_threshold > 0:
num_new_tokens = min(
num_new_tokens,
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
@ -303,9 +303,9 @@ class Scheduler(SchedulerInterface):
num_computed_tokens -= self.block_size
num_new_tokens = self.block_size
computed_blocks.pop()
if self.scheduler_config.long_prefill_token_threshold > 0:
num_new_tokens = min(
num_new_tokens,
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
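
The reworked clamp above only overwrites num_new_tokens when a positive long_prefill_token_threshold is actually smaller than it; the token budget is then applied as before. A worked example with made-up numbers:

```python
# Made-up numbers showing how the two caps compose for a long prefill.
num_new_tokens = 5000                 # remaining prompt tokens for the request
long_prefill_token_threshold = 2048   # per-step cap on long prefills
token_budget = 1500                   # tokens left in this scheduling step

if 0 < long_prefill_token_threshold < num_new_tokens:
    num_new_tokens = long_prefill_token_threshold   # capped to 2048
num_new_tokens = min(num_new_tokens, token_budget)  # further capped to 1500
assert num_new_tokens == 1500
```
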
@ -433,6 +433,18 @@ class Scheduler(SchedulerInterface):
grammar_bitmask=grammar_bitmask,
)
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
# original number of scheduled tokens to determine input IDs.
# 2. Advancing the number of computed tokens here allows us to
# schedule the prefill request again immediately in the next
# scheduling step.
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
# computed tokens will be adjusted in update_from_output.
for req_id, num_scheduled_token in num_scheduled_tokens.items():
self.requests[req_id].num_computed_tokens += num_scheduled_token
self.finished_req_ids = set()
return scheduler_output
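
The comment block above captures the core of this refactor: num_computed_tokens is advanced optimistically as soon as tokens are scheduled, and rolled back later only if speculative tokens are rejected. A simplified sketch of that accounting, using stand-in names rather than the real Scheduler and Request classes:

```python
# Simplified sketch only; _Request and these function names are stand-ins,
# not the real vLLM types or methods.
class _Request:
    def __init__(self) -> None:
        self.num_computed_tokens = 0

def advance_on_schedule(requests: dict, num_scheduled_tokens: dict) -> None:
    # Done AFTER scheduler_output is built, so the output still carries the
    # originally scheduled token counts for this step (point 1 above), while
    # the request can be scheduled again immediately next step (point 2).
    for req_id, n in num_scheduled_tokens.items():
        requests[req_id].num_computed_tokens += n

def adjust_on_output(request: _Request, scheduled_spec_token_ids: list,
                     generated_token_ids: list) -> None:
    # Point 3 above: roll back speculative tokens that were scheduled but
    # later rejected by the sampler.
    num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                           len(generated_token_ids))
    request.num_computed_tokens -= num_tokens_rejected
```
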
@ -561,28 +573,19 @@ class Scheduler(SchedulerInterface):
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index]
if req_id not in scheduler_output.scheduled_spec_decode_tokens:
# When the request's num_computed_tokens catches up with
# its num_tokens, the request generates output tokens.
# Otherwise, we ignore the sampler output for the request.
request.num_computed_tokens += num_tokens_scheduled
assert request.num_computed_tokens <= request.num_tokens
else:
# num_computed_tokens_step represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections.
# It is calculated as:
# num_computed_tokens_step = num_scheduled_tokens -
# num_tokens_rejected,
# where num_tokens_rejected is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens[req_id])
num_computed_tokens_step = num_scheduled_tokens[req_id] - (
len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens += num_computed_tokens_step
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens, which is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
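
To make the rejection arithmetic above concrete, a worked example with made-up token IDs: three speculative tokens were scheduled, only the first was accepted, and one new token was sampled after it.

```python
# Made-up values illustrating the formula used in update_from_output above.
scheduled_spec_token_ids = [11, 12, 13]  # 3 spec tokens scheduled (plus 1 bonus slot)
generated_token_ids = [11, 42]           # 1 accepted spec token + 1 newly sampled token

num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                       len(generated_token_ids))
assert num_tokens_rejected == 2          # num_computed_tokens is reduced by 2
```
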
@ -605,24 +608,26 @@ class Scheduler(SchedulerInterface):
new_logprobs = None
new_token_ids: list[int] = []
if request.num_computed_tokens >= request.num_tokens:
for output_token_id in generated_token_ids:
request.append_output_token_ids(output_token_id)
new_token_ids.append(output_token_id)
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
# to return empty token ids for the request.
for output_token_id in generated_token_ids:
request.append_output_token_ids(output_token_id)
new_token_ids.append(output_token_id)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
self._free_request(request)
break
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
self._free_request(request)
break
# Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None:
assert logprobs is not None
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
# Extract sample logprobs if needed.
if (request.sampling_params.logprobs is not None
and logprobs is not None):
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and request.use_structured_output:
# NOTE: structured_output_request

View File

@ -107,14 +107,33 @@ class RejectionSampler(nn.Module):
@staticmethod
def parse_output(
output_token_ids: torch.Tensor,
ignored_req_idxs: list[int],
vocab_size: int,
) -> list[list[int]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
ignored_req_idxs: The indices of the requests that should not be
sampled. This is usually because the request is still in the
prefill phase.
vocab_size: The size of the vocabulary.
Returns:
A list of lists of token IDs.
"""
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))
ignored_req_idx_set = set(ignored_req_idxs)
outputs = [
row[valid_mask[i]].tolist()
if i not in ignored_req_idx_set else []
for i, row in enumerate(output_token_ids_np)
]
return outputs
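
A hedged usage example for the updated parse_output; the tensor contents, vocab size, and import path are illustrative assumptions, not taken from this commit.

```python
import torch

# Assumed import path; adjust to wherever RejectionSampler actually lives.
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler

P = PLACEHOLDER_TOKEN_ID                 # marks rejected slots
output_token_ids = torch.tensor([        # batch of 3, max_spec_len = 2
    [5, 7, P],                           # request 0: 1 spec token accepted + bonus token
    [9, P, P],                           # request 1: still prefilling -> ignored below
    [3, 4, 6],                           # request 2: everything accepted
])

tokens = RejectionSampler.parse_output(
    output_token_ids,
    ignored_req_idxs=[1],                # request 1 should not be sampled
    vocab_size=32_000,
)
assert tokens == [[5, 7], [], [3, 4, 6]]
```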

View File

@ -1085,8 +1085,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
for i, generator in self.input_batch.generators.items():
req_id = self.input_batch.req_ids[i]
discard_sampled_tokens_req_indices = []
for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
@ -1094,7 +1094,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we can clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
@ -1114,10 +1119,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
# Discard the sampled tokens for requests that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids, self.input_batch.vocab_size)
sampled_token_ids,
discard_sampled_tokens_req_indices,
self.input_batch.vocab_size,
)
if not self.use_spec_decode:
spec_token_ids = None
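
For completeness, a small self-contained illustration (made-up values) of the non-spec-decode path above: clearing the rows listed in discard_sampled_tokens_req_indices is what produces the empty token lists the scheduler tests now expect for partially prefilled requests.

```python
import torch

# Made-up sampled tokens for 3 requests (max_gen_len == 1); request index 1 is
# a partial prefill, so its sampled token is discarded before returning.
sampled_token_ids = torch.tensor([[101], [102], [103]])
discard_sampled_tokens_req_indices = [1]

valid_sampled_token_ids = sampled_token_ids.tolist()
for i in discard_sampled_tokens_req_indices:
    valid_sampled_token_ids[i].clear()
assert valid_sampled_token_ids == [[101], [], [103]]
```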