[V1] Refactor num_computed_tokens logic (#15307)
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

This commit is contained in:
parent fb22be5817
commit 54aa619459
@@ -244,7 +244,9 @@ def test_schedule_partial_requests():
     model_runner_output = ModelRunnerOutput(
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[0] for _ in range(len(requests))],
+        # Only the first request has a sampled token id because
+        # the rest of the requests are still being prefilled.
+        sampled_token_ids=[[0], [], []],
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
@@ -266,7 +268,7 @@ def test_schedule_partial_requests():


 @pytest.mark.parametrize("enable_prefix_caching", [True, False])
-def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
+def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     """Test scheduling behavior with concurrent partial requests.

     This test verifies that: there are multiple long prefill requests in the
@@ -304,7 +306,7 @@ def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
     model_runner_output = ModelRunnerOutput(
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[0] for _ in range(len(requests))],
+        sampled_token_ids=[[] for _ in range(len(requests))],
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
@@ -325,6 +327,14 @@ def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
     # Schedule the third step. All three requests are running.
     # First and second requests are in the decode stage.
     # All the remaining tokens in the third request are processed.
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[request.request_id for request in requests],
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
     scheduler.update_from_output(output1, model_runner_output)
     output2 = scheduler.schedule()
     assert len(scheduler.running) == 3
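
The updated tests encode the new convention that only a request whose prompt has been fully computed produces a sampled token id; a request still being prefilled contributes an empty list. A minimal standalone sketch of that rule (plain Python, not the vLLM API; the numbers below are made up, only the output pattern matters):

# Illustrative only: when is a request expected to yield a sampled token id?
def expected_sampled_ids(num_computed_tokens: list[int],
                         num_prompt_tokens: list[int]) -> list[list[int]]:
    # A request samples a token only once its whole prompt has been computed;
    # a request that is still being prefilled contributes an empty list.
    return [[0] if computed >= prompt else []
            for computed, prompt in zip(num_computed_tokens, num_prompt_tokens)]

# First step: only the first request has finished its prompt.
assert expected_sampled_ids([800, 400, 100], [800, 800, 800]) == [[0], [], []]
# Third step: the first two requests decode, the third is still prefilling.
assert expected_sampled_ids([801, 801, 500], [800, 800, 800]) == [[0], [0], []]
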
@@ -231,8 +231,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
    Test that the engine can handle multiple concurrent batches.
    """

-    def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:
+    def make_request_with_max_tokens(req_id: int,
+                                     max_tokens: int) -> EngineCoreRequest:
        request = make_request()
+        request.request_id = req_id
        request.sampling_params.max_tokens = max_tokens
        return request

@@ -279,6 +281,8 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
            # Avoid all requests being scheduled once.
            enable_prefix_caching=False,
            max_num_batched_tokens=10,
+            # Reduce startup time.
+            enforce_eager=True,
        )
        vllm_config = engine_args.create_engine_config()
        engine_core = EngineCore(vllm_config=vllm_config,
@@ -286,13 +290,13 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
                                 executor_class=DummyExecutor)
        assert engine_core.batch_queue is not None

-        # Add two requests in a row.
-        req = make_request_with_max_tokens(5)
-        engine_core.add_request(req)
-        req = make_request_with_max_tokens(5)
-        engine_core.add_request(req)
+        # Add two requests in a row. Each request has 12 prompt tokens.
+        req0 = make_request_with_max_tokens(0, 5)
+        engine_core.add_request(req0)
+        req1 = make_request_with_max_tokens(1, 5)
+        engine_core.add_request(req1)

-        # First saturate the batch queue.
+        # Schedule Batch 1: (10, req0)
        assert engine_core.step_with_batch_queue() is None
        assert engine_core.batch_queue.qsize() == 1
        assert engine_core.step_with_batch_queue() is None
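
Per the new comments in this test, each request has 12 prompt tokens while the per-step budget is max_num_batched_tokens=10, so the first scheduling step can only cover 10 tokens of req0 and the remainder is picked up later. A rough standalone sketch of that budget split (illustrative, not the vLLM scheduler):

# Illustrative budget split: a 12-token prompt against a 10-token step budget.
def split_prompt(num_prompt_tokens: int, budget_per_step: int) -> list[int]:
    """Return how many prompt tokens get scheduled in each step."""
    scheduled, remaining = [], num_prompt_tokens
    while remaining > 0:
        step = min(remaining, budget_per_step)
        scheduled.append(step)
        remaining -= step
    return scheduled

# Batch 1 covers 10 tokens of req0; its last 2 prompt tokens come later.
assert split_prompt(12, 10) == [10, 2]
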
@@ -153,9 +153,9 @@ class Scheduler(SchedulerInterface):

            num_new_tokens = (request.num_tokens_with_spec -
                              request.num_computed_tokens)
-            if self.scheduler_config.long_prefill_token_threshold > 0:
-                num_new_tokens = min(
-                    num_new_tokens,
-                    self.scheduler_config.long_prefill_token_threshold)
+            if (0 < self.scheduler_config.long_prefill_token_threshold <
+                    num_new_tokens):
+                num_new_tokens = (
+                    self.scheduler_config.long_prefill_token_threshold)
            num_new_tokens = min(num_new_tokens, token_budget)
            assert num_new_tokens > 0
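
This hunk and the next one rewrite the long-prefill clamp with a chained comparison so it only fires when the threshold is enabled (non-zero) and actually smaller than the requested token count; the result matches the previous min-based form. A small standalone illustration (not vLLM code):

def clamp_to_threshold(num_new_tokens: int, threshold: int) -> int:
    # Chained comparison: applies only when 0 < threshold < num_new_tokens.
    if 0 < threshold < num_new_tokens:
        num_new_tokens = threshold
    return num_new_tokens

assert clamp_to_threshold(100, 64) == 64   # long prefill gets chunked
assert clamp_to_threshold(32, 64) == 32    # under the threshold: unchanged
assert clamp_to_threshold(100, 0) == 100   # threshold disabled (0)
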
@@ -303,9 +303,9 @@ class Scheduler(SchedulerInterface):
                    num_computed_tokens -= self.block_size
                    num_new_tokens = self.block_size
                    computed_blocks.pop()
-                if self.scheduler_config.long_prefill_token_threshold > 0:
-                    num_new_tokens = min(
-                        num_new_tokens,
-                        self.scheduler_config.long_prefill_token_threshold)
+                if (0 < self.scheduler_config.long_prefill_token_threshold <
+                        num_new_tokens):
+                    num_new_tokens = (
+                        self.scheduler_config.long_prefill_token_threshold)
                num_new_tokens = min(num_new_tokens, token_budget)
                assert num_new_tokens > 0
@@ -433,6 +433,18 @@ class Scheduler(SchedulerInterface):
            grammar_bitmask=grammar_bitmask,
        )

+        # Advance the number of computed tokens for the request AFTER
+        # the request is scheduled.
+        # 1. The scheduler_output of the current step has to include the
+        #    original number of scheduled tokens to determine input IDs.
+        # 2. Advancing the number of computed tokens here allows us to
+        #    schedule the prefill request again immediately in the next
+        #    scheduling step.
+        # 3. If some tokens (e.g. spec tokens) are rejected later, the number
+        #    of computed tokens will be adjusted in update_from_output.
+        for req_id, num_scheduled_token in num_scheduled_tokens.items():
+            self.requests[req_id].num_computed_tokens += num_scheduled_token
+
        self.finished_req_ids = set()
        return scheduler_output
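
This is the heart of the refactor: num_computed_tokens is advanced optimistically at schedule time and only corrected afterwards if tokens are rejected. A condensed standalone sketch of the bookkeeping (plain Python; `Request` here is a stand-in, not the vLLM class, and the numbers are illustrative):

# Condensed illustration of the new schedule-time bookkeeping.
class Request:
    def __init__(self) -> None:
        self.num_computed_tokens = 0

requests = {"req-0": Request()}
num_scheduled_tokens = {"req-0": 512}   # e.g. one chunk of a long prefill

# At the end of schedule(): optimistically count every scheduled token as
# computed, so the same prefill can be scheduled again in the next step.
for req_id, num_scheduled_token in num_scheduled_tokens.items():
    requests[req_id].num_computed_tokens += num_scheduled_token

assert requests["req-0"].num_computed_tokens == 512
# If spec tokens are later rejected, update_from_output subtracts them again.
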
@@ -561,28 +573,19 @@ class Scheduler(SchedulerInterface):

            req_index = model_runner_output.req_id_to_index[req_id]
            generated_token_ids = sampled_token_ids[req_index]
-            if req_id not in scheduler_output.scheduled_spec_decode_tokens:
-                # When the request's num_computed_tokens catches up
-                # its num_tokens, the request generates output tokens.
-                # Otherwise, we ignore the sampler output for the request.
-                request.num_computed_tokens += num_tokens_scheduled
-                assert request.num_computed_tokens <= request.num_tokens
-            else:
-                # num_computed_tokens_step represents the number of tokens
-                # processed in the current step, considering scheduled
-                # tokens and rejections.
-                # It is calculated as:
-                # num_computed_tokens_step = num_scheduled_tokens -
-                #   num_tokens_rejected,
-                # where num_tokens_rejected is given by:
-                # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
-                scheduled_spec_token_ids = (
-                    scheduler_output.scheduled_spec_decode_tokens[req_id])

-                num_computed_tokens_step = num_scheduled_tokens[req_id] - (
-                    len(scheduled_spec_token_ids) + 1 -
-                    len(generated_token_ids))
-                request.num_computed_tokens += num_computed_tokens_step
+            scheduled_spec_token_ids = (
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id))
+            if scheduled_spec_token_ids:
+                # num_computed_tokens represents the number of tokens
+                # processed in the current step, considering scheduled
+                # tokens and rejections. If some tokens are rejected,
+                # num_computed_tokens is decreased by the number of rejected
+                # tokens, which is given by:
+                # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
+                num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
+                                       len(generated_token_ids))
+                request.num_computed_tokens -= num_tokens_rejected

            cached_encoder_input_ids = (
                self.encoder_cache_manager.get_cached_input_ids(request))
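
A worked example of the rejection formula used above, with made-up token values; the "+1" accounts for the token sampled after the last accepted draft:

# Worked example of the rejection accounting; token ids are illustrative.
scheduled_spec_token_ids = [11, 22, 33]   # 3 draft tokens were scheduled
generated_token_ids = [11, 44]            # 11 accepted, 22 rejected, 44 is
                                          # the token sampled after 11
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                       len(generated_token_ids))       # 3 + 1 - 2 = 2
assert num_tokens_rejected == 2
# Schedule-time accounting added 4 tokens (1 real + 3 draft); only 2 of them
# survived verification, so num_computed_tokens is rolled back by 2.
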
@@ -605,24 +608,26 @@ class Scheduler(SchedulerInterface):
            new_logprobs = None
            new_token_ids: list[int] = []

-            if request.num_computed_tokens >= request.num_tokens:
-                for output_token_id in generated_token_ids:
-                    request.append_output_token_ids(output_token_id)
-                    new_token_ids.append(output_token_id)
+            # Append generated tokens and check for stop. Note that if
+            # a request is still being prefilled, we expect the model runner
+            # to return empty token ids for the request.
+            for output_token_id in generated_token_ids:
+                request.append_output_token_ids(output_token_id)
+                new_token_ids.append(output_token_id)

                # Check for stop and update request state.
                # This must be called before we make the EngineCoreOutput.
                stopped = check_stop(request, self.max_model_len)
                if stopped:
                    self._free_request(request)
                    break

            # Extract sample logprobs if needed.
-            if request.sampling_params.logprobs is not None:
-                assert logprobs is not None
+            if (request.sampling_params.logprobs is not None
+                    and logprobs is not None):
                # NOTE: once we support N tokens per step (spec decode),
                # the outer lists can be of length > 1.
                new_logprobs = logprobs.slice(req_index, req_index + 1)

            if new_token_ids and request.use_structured_output:
                # NOTE: structured_output_request
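
Because a still-prefilling request now reaches this loop with an empty generated_token_ids list, the append/stop logic simply does nothing for it, so the old num_computed_tokens guard is no longer needed. A trivial standalone illustration (not the vLLM code path):

def append_new_tokens(generated_token_ids: list[int],
                      output_token_ids: list[int]) -> list[int]:
    # Mirrors the shape of the loop above: nothing happens for an empty list.
    new_token_ids = []
    for token_id in generated_token_ids:
        output_token_ids.append(token_id)
        new_token_ids.append(token_id)
    return new_token_ids

assert append_new_tokens([], output_token_ids=[]) == []      # prefilling
assert append_new_tokens([7], output_token_ids=[]) == [7]    # decoding
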
@@ -107,14 +107,33 @@ class RejectionSampler(nn.Module):
    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
+        ignored_req_idxs: list[int],
        vocab_size: int,
    ) -> list[list[int]]:
+        """Parse the output of the rejection sampler.
+
+        Args:
+            output_token_ids: The sampled token IDs in shape
+                [batch_size, max_spec_len + 1]. The rejected tokens are
+                replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
+                and will be filtered out in this function.
+            ignored_req_idxs: The indices of the requests that should not be
+                sampled. This is usually because the request is still in the
+                prefill phase.
+            vocab_size: The size of the vocabulary.
+
+        Returns:
+            A list of lists of token IDs.
+        """
        output_token_ids_np = output_token_ids.cpu().numpy()
        # Create mask for valid tokens.
        valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
                      (output_token_ids_np < vocab_size))

+        ignored_req_idx_set = set(ignored_req_idxs)
        outputs = [
            row[valid_mask[i]].tolist()
+            if i not in ignored_req_idx_set else []
            for i, row in enumerate(output_token_ids_np)
        ]
        return outputs
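
A runnable, framework-free sketch of the same filtering logic using NumPy; the placeholder value and input tokens are illustrative, and the real method operates on a torch.Tensor:

import numpy as np

PLACEHOLDER_TOKEN_ID = -1   # illustrative placeholder value

def parse_output_np(output_token_ids: np.ndarray,
                    ignored_req_idxs: list[int],
                    vocab_size: int) -> list[list[int]]:
    # Keep only tokens that are not placeholders and are within the vocab.
    valid_mask = ((output_token_ids != PLACEHOLDER_TOKEN_ID) &
                  (output_token_ids < vocab_size))
    ignored = set(ignored_req_idxs)
    return [row[valid_mask[i]].tolist() if i not in ignored else []
            for i, row in enumerate(output_token_ids)]

tokens = np.array([[5, 7, PLACEHOLDER_TOKEN_ID],    # two accepted tokens
                   [9, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
                   [3, 4, 6]])                      # still prefilling
assert parse_output_np(tokens, ignored_req_idxs=[2],
                       vocab_size=32000) == [[5, 7], [9], []]
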
@@ -1085,8 +1085,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
-        for i, generator in self.input_batch.generators.items():
-            req_id = self.input_batch.req_ids[i]
+        discard_sampled_tokens_req_indices = []
+        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
@@ -1094,7 +1094,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                # Ignore the sampled token for partial prefills.
                # Rewind the generator state as if the token was not sampled.
                # This relies on cuda-specific torch-internal impl details
-                generator.set_offset(generator.get_offset() - 4)
+                generator = self.input_batch.generators.get(i)
+                if generator is not None:
+                    generator.set_offset(generator.get_offset() - 4)
+                # Record the index of the request that should not be sampled,
+                # so that we can clear the sampled tokens before returning.
+                discard_sampled_tokens_req_indices.append(i)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
@@ -1114,10 +1119,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
+            # Mask out the sampled tokens that should not be sampled.
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[i].clear()
        else:
            # Includes spec decode tokens.
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids, self.input_batch.vocab_size)
+                sampled_token_ids,
+                discard_sampled_tokens_req_indices,
+                self.input_batch.vocab_size,
+            )

        if not self.use_spec_decode:
            spec_token_ids = None
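
This is the runner-side counterpart of the scheduler change: sampled tokens for partially prefilled requests are cleared before the output is returned, so the scheduler sees empty lists for them. A minimal standalone illustration with made-up token ids:

# Illustrative: clear sampled ids for requests whose prefill is unfinished.
valid_sampled_token_ids = [[101], [202], [303]]
discard_sampled_tokens_req_indices = [1, 2]   # e.g. partial prefills

for i in discard_sampled_tokens_req_indices:
    valid_sampled_token_ids[i].clear()

assert valid_sampled_token_ids == [[101], [], []]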