[V1] Refactor num_computed_tokens logic (#15307)

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Cody Yu 2025-03-26 21:54:36 -07:00 committed by GitHub
parent fb22be5817
commit 54aa619459
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 106 additions and 57 deletions

View File

@ -244,7 +244,9 @@ def test_schedule_partial_requests():
model_runner_output = ModelRunnerOutput( model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests], req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index, req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))], # Only the first request has a sampled token id because
# the rest requests are still being prefilled.
sampled_token_ids=[[0], [], []],
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
@ -266,7 +268,7 @@ def test_schedule_partial_requests():
@pytest.mark.parametrize("enable_prefix_caching", [True, False]) @pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool): def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
"""Test scheduling behavior with concurrent partial requests. """Test scheduling behavior with concurrent partial requests.
This test verifies that: there are multiple long prefill requests in the This test verifies that: there are multiple long prefill requests in the
@ -304,7 +306,7 @@ def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
model_runner_output = ModelRunnerOutput( model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests], req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index, req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))], sampled_token_ids=[[] for _ in range(len(requests))],
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
@ -325,6 +327,14 @@ def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
# Schedule the third step. All three requests are running. # Schedule the third step. All three requests are running.
# First and second requests are in the decode stage. # First and second requests are in the decode stage.
# All the remaining tokens in the third request are processed. # All the remaining tokens in the third request are processed.
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
scheduler.update_from_output(output1, model_runner_output) scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule() output2 = scheduler.schedule()
assert len(scheduler.running) == 3 assert len(scheduler.running) == 3

View File

@ -231,8 +231,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
Test that the engine can handle multiple concurrent batches. Test that the engine can handle multiple concurrent batches.
""" """
def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest: def make_request_with_max_tokens(req_id: int,
max_tokens: int) -> EngineCoreRequest:
request = make_request() request = make_request()
request.request_id = req_id
request.sampling_params.max_tokens = max_tokens request.sampling_params.max_tokens = max_tokens
return request return request
@ -279,6 +281,8 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Avoid all requests being scheduled once. # Avoid all requests being scheduled once.
enable_prefix_caching=False, enable_prefix_caching=False,
max_num_batched_tokens=10, max_num_batched_tokens=10,
# Reduce startup time.
enforce_eager=True,
) )
vllm_config = engine_args.create_engine_config() vllm_config = engine_args.create_engine_config()
engine_core = EngineCore(vllm_config=vllm_config, engine_core = EngineCore(vllm_config=vllm_config,
@ -286,13 +290,13 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
executor_class=DummyExecutor) executor_class=DummyExecutor)
assert engine_core.batch_queue is not None assert engine_core.batch_queue is not None
# Add two requests in a row. # Add two requests in a row. Each request have 12 prompt tokens.
req = make_request_with_max_tokens(5) req0 = make_request_with_max_tokens(0, 5)
engine_core.add_request(req) engine_core.add_request(req0)
req = make_request_with_max_tokens(5) req1 = make_request_with_max_tokens(1, 5)
engine_core.add_request(req) engine_core.add_request(req1)
# First saturate the batch queue. # Schedule Batch 1: (10, req0)
assert engine_core.step_with_batch_queue() is None assert engine_core.step_with_batch_queue() is None
assert engine_core.batch_queue.qsize() == 1 assert engine_core.batch_queue.qsize() == 1
assert engine_core.step_with_batch_queue() is None assert engine_core.step_with_batch_queue() is None

View File

@ -153,9 +153,9 @@ class Scheduler(SchedulerInterface):
num_new_tokens = (request.num_tokens_with_spec - num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens) request.num_computed_tokens)
if self.scheduler_config.long_prefill_token_threshold > 0: if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens = min( num_new_tokens):
num_new_tokens, num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold) self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0 assert num_new_tokens > 0
@ -303,9 +303,9 @@ class Scheduler(SchedulerInterface):
num_computed_tokens -= self.block_size num_computed_tokens -= self.block_size
num_new_tokens = self.block_size num_new_tokens = self.block_size
computed_blocks.pop() computed_blocks.pop()
if self.scheduler_config.long_prefill_token_threshold > 0: if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens = min( num_new_tokens):
num_new_tokens, num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold) self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0 assert num_new_tokens > 0
@ -433,6 +433,18 @@ class Scheduler(SchedulerInterface):
grammar_bitmask=grammar_bitmask, grammar_bitmask=grammar_bitmask,
) )
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
# original number of scheduled tokens to determine input IDs.
# 2. Advance the number of computed tokens here allowing us to
# schedule the prefill request again immediately in the next
# scheduling step.
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
# computed tokens will be adjusted in update_from_output.
for req_id, num_scheduled_token in num_scheduled_tokens.items():
self.requests[req_id].num_computed_tokens += num_scheduled_token
self.finished_req_ids = set() self.finished_req_ids = set()
return scheduler_output return scheduler_output
@ -561,28 +573,19 @@ class Scheduler(SchedulerInterface):
req_index = model_runner_output.req_id_to_index[req_id] req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index] generated_token_ids = sampled_token_ids[req_index]
if req_id not in scheduler_output.scheduled_spec_decode_tokens:
# When the request's num_computed_tokens catches up
# its num_tokens, the request generates output tokens.
# Otherwise, we ignore the sampler output for the request.
request.num_computed_tokens += num_tokens_scheduled
assert request.num_computed_tokens <= request.num_tokens
else:
# num_computed_tokens_step represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections.
# It is calculated as:
# num_computed_tokens_step = num_scheduled_tokens -
# num_tokens_rejected,
# where num_tokens_rejected is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens[req_id])
num_computed_tokens_step = num_scheduled_tokens[req_id] - ( scheduled_spec_token_ids = (
len(scheduled_spec_token_ids) + 1 - scheduler_output.scheduled_spec_decode_tokens.get(req_id))
len(generated_token_ids)) if scheduled_spec_token_ids:
request.num_computed_tokens += num_computed_tokens_step # num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
cached_encoder_input_ids = ( cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request)) self.encoder_cache_manager.get_cached_input_ids(request))
@ -605,24 +608,26 @@ class Scheduler(SchedulerInterface):
new_logprobs = None new_logprobs = None
new_token_ids: list[int] = [] new_token_ids: list[int] = []
if request.num_computed_tokens >= request.num_tokens: # Append generated tokens and check for stop. Note that if
for output_token_id in generated_token_ids: # a request is still being prefilled, we expect the model runner
request.append_output_token_ids(output_token_id) # to return empty token ids for the request.
new_token_ids.append(output_token_id) for output_token_id in generated_token_ids:
request.append_output_token_ids(output_token_id)
new_token_ids.append(output_token_id)
# Check for stop and update request state. # Check for stop and update request state.
# This must be called before we make the EngineCoreOutput. # This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len) stopped = check_stop(request, self.max_model_len)
if stopped: if stopped:
self._free_request(request) self._free_request(request)
break break
# Extract sample logprobs if needed. # Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None: if (request.sampling_params.logprobs is not None
assert logprobs is not None and logprobs is not None):
# NOTE: once we support N tokens per step (spec decode), # NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1. # the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1) new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and request.use_structured_output: if new_token_ids and request.use_structured_output:
# NOTE: structured_output_request # NOTE: structured_output_request

View File

@ -107,14 +107,33 @@ class RejectionSampler(nn.Module):
@staticmethod @staticmethod
def parse_output( def parse_output(
output_token_ids: torch.Tensor, output_token_ids: torch.Tensor,
ignored_req_idxs: list[int],
vocab_size: int, vocab_size: int,
) -> list[list[int]]: ) -> list[list[int]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
ignored_req_idxs: The indices of the requests that should not be
sampled. This is usually because the request is still in the
prefill phase.
vocab_size: The size of the vocabulary.
Returns:
A list of lists of token IDs.
"""
output_token_ids_np = output_token_ids.cpu().numpy() output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens. # Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size)) (output_token_ids_np < vocab_size))
ignored_req_idx_set = set(ignored_req_idxs)
outputs = [ outputs = [
row[valid_mask[i]].tolist() row[valid_mask[i]].tolist()
if i not in ignored_req_idx_set else []
for i, row in enumerate(output_token_ids_np) for i, row in enumerate(output_token_ids_np)
] ]
return outputs return outputs

View File

@ -1085,8 +1085,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): The following loop can be slow since it iterates over # TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize. # the requests one by one. Optimize.
for i, generator in self.input_batch.generators.items(): discard_sampled_tokens_req_indices = []
req_id = self.input_batch.req_ids[i] for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id] req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens + seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id]) scheduler_output.num_scheduled_tokens[req_id])
@ -1094,7 +1094,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Ignore the sampled token for partial prefills. # Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled. # Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details # This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4) generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
# NOTE: GPU -> CPU Sync happens here. # NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point. # Move as many CPU operations as possible before this sync point.
@ -1114,10 +1119,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if max_gen_len == 1: if max_gen_len == 1:
# No spec decode tokens. # No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist() valid_sampled_token_ids = sampled_token_ids.tolist()
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
else: else:
# Includes spec decode tokens. # Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output( valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids, self.input_batch.vocab_size) sampled_token_ids,
discard_sampled_tokens_req_indices,
self.input_batch.vocab_size,
)
if not self.use_spec_decode: if not self.use_spec_decode:
spec_token_ids = None spec_token_ids = None