[V1] Remove scheduling constraint on partial requests (#12674)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent d1ca7df84d
commit 18a88fcccc

tests/v1/core/test_scheduler.py | 214 lines (new file)
@@ -0,0 +1,214 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus


def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
) -> Scheduler:
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_num_batched_tokens,
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
    )
    cache_config.num_gpu_blocks = 10000
    return Scheduler(scheduler_config,
                     model_config,
                     cache_config,
                     lora_config=None)


def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[List[PlaceholderRange]] = None,
):
    sampling_params = SamplingParams()
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
        else:
            mm_position = None
            mm_inputs = None
        request = Request(
            request_id=f"{i}",
            prompt=None,
            prompt_token_ids=[i] * num_tokens,
            sampling_params=sampling_params,
            multi_modal_inputs=mm_inputs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
            eos_token_id=None,
            arrival_time=0,
        )
        requests.append(request)
    return requests


def test_add_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)

    for i, request in enumerate(requests):
        scheduler.add_request(request)
        assert request.request_id in scheduler.requests
        assert len(scheduler.waiting) == i + 1


def test_finish_request():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_ABORTED)
        assert request.request_id not in scheduler.requests
        assert len(scheduler.waiting) == 9 - i


def test_get_num_unfinished_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_STOPPED)
        assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1


def test_schedule():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    # Test initial scheduling
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert len(output.scheduled_cached_reqs) == 0
    assert len(output.finished_req_ids) == 0
    # Verify all requests are scheduled.
    for req_id, num_tokens in output.num_scheduled_tokens.items():
        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)

    # Verify requests moved from waiting to running
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == len(requests)
    for i, request in enumerate(requests):
        assert scheduler.running[i] == request


def test_schedule_multimodal_requests():
    scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf")
    mm_positions = [[PlaceholderRange(offset=i, length=100)]
                    for i in range(10)]
    requests = create_requests(
        num_requests=10,
        num_tokens=200,
        mm_positions=mm_positions,
    )
    for request in requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert len(output.scheduled_cached_reqs) == 0
    assert len(output.finished_req_ids) == 0
    for req_id, num_tokens in output.num_scheduled_tokens.items():
        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
    assert len(output.scheduled_encoder_inputs) == 10
    for req_id, encoder_input in output.scheduled_encoder_inputs.items():
        assert len(encoder_input) == 1


def test_schedule_partial_requests():
    """Test scheduling behavior with partial requests.

    This test verifies that:
    1. The scheduler can handle multiple partial requests in a single step
       when constrained by encoder budget.
    2. A request in RUNNING state may be unscheduled in subsequent steps if
       there is insufficient encoder budget.
    """
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        max_num_batched_tokens=1024,
    )
    mm_positions = [[PlaceholderRange(offset=100, length=600)]
                    for _ in range(3)]
    requests = create_requests(
        num_requests=3,
        num_tokens=800,
        mm_positions=mm_positions,
    )
    for request in requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 3
    assert len(output.scheduled_cached_reqs) == 0
    assert len(output.finished_req_ids) == 0

    assert scheduler.max_num_encoder_input_tokens == 1024
    # The first request is scheduled fully.
    assert output.num_scheduled_tokens[requests[0].request_id] == 800
    # The second request is scheduled partially.
    # The <img> tokens are not scheduled because of the encoder budget.
    assert output.num_scheduled_tokens[requests[1].request_id] == 100
    # The third request is also scheduled partially.
    # The <img> tokens are not scheduled because of the encoder budget.
    assert output.num_scheduled_tokens[requests[2].request_id] == 100
    req_to_index = {
        request.request_id: i
        for i, request in enumerate(requests)
    }
    model_runner_output = ModelRunnerOutput(
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[0] * len(requests),
        logprob_token_ids_cpu=None,
        logprobs_cpu=None,
    )
    scheduler.update_from_output(output, model_runner_output)

    # Schedule the next step.
    # Only the first and second requests are scheduled.
    # The third request is in the RUNNING state but not scheduled in this step
    # because of the encoder budget.
    output = scheduler.schedule()
    assert len(scheduler.running) == 3
    assert len(output.scheduled_new_reqs) == 0
    assert len(output.scheduled_cached_reqs) == 2
    assert len(output.finished_req_ids) == 0
    assert output.num_scheduled_tokens[requests[0].request_id] == 1
    assert output.num_scheduled_tokens[requests[1].request_id] == 700
    assert requests[2].request_id not in output.num_scheduled_tokens
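As a side note on `test_schedule_partial_requests`: the expected values 800/100/100 follow directly from the token and encoder budgets, and the sketch below redoes that arithmetic in plain Python. It is a simplification for illustration only (the real decision lives in `_try_schedule_encoder_inputs` and the scheduling loop), and the variable names are made up.

```python
# Hand-check of the budgets behind test_schedule_partial_requests (illustrative
# simplification, not the scheduler's actual code path).
token_budget = 1024      # max_num_batched_tokens
encoder_budget = 1024    # max_num_encoder_input_tokens
prompt_len, img_offset, img_len = 800, 100, 600

scheduled = {}
for req_id in ("0", "1", "2"):
    if encoder_budget >= img_len and token_budget >= img_offset + img_len:
        # The image fits: schedule the whole prompt and charge the encoder.
        num_tokens = min(prompt_len, token_budget)
        encoder_budget -= img_len
    else:
        # Encoder budget exhausted: only the text before the <img> placeholder
        # can be scheduled this step.
        num_tokens = min(img_offset, token_budget)
    token_budget -= num_tokens
    scheduled[req_id] = num_tokens

assert scheduled == {"0": 800, "1": 100, "2": 100}
```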
@@ -67,10 +67,10 @@ class Scheduler:
         # This is flushed at the end of each scheduling step.
         self.finished_req_ids: Set[str] = set()
 
-        # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
+        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        # Request id -> RunningRequestData
-        self.running_reqs_data: Dict[str, RunningRequestData] = {}
+        # Request id -> CachedRequestData
+        self._cached_reqs_data: Dict[str, CachedRequestData] = {}
 
         # Encoder-related.
         # Calculate encoder cache size if applicable
@@ -115,17 +115,8 @@
         encoder_budget = self.max_num_encoder_input_tokens
 
         # First, schedule the RUNNING requests.
-        # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be
-        # in the "partial" state, where the request has some tokens computed
-        # but not all. The constraint is due to the persistent batch in the
-        # V1 model runner.
-        # TODO(woosuk): Remove this constraint after refactoring model runner.
-        has_partial_request = False
         req_index = 0
-        while req_index < len(self.running):
-            # Only the last request in the RUNNING queue can be "partial".
-            assert not has_partial_request
-            assert token_budget > 0
+        while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
             num_new_tokens = request.num_tokens - request.num_computed_tokens
             num_new_tokens = min(num_new_tokens, token_budget)
@@ -137,7 +128,14 @@
                                                   request.num_computed_tokens,
                                                   num_new_tokens,
                                                   encoder_budget))
-            assert num_new_tokens > 0
+            if num_new_tokens == 0:
+                # The request cannot be scheduled because the encoder budget
+                # or the encoder cache is exhausted.
+                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
+                # we do not strictly follow the FCFS scheduling policy and
+                # allow the lower-priority requests to be scheduled.
+                req_index += 1
+                continue
 
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
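The `continue` in the hunk above is what lets several partial requests coexist: a request whose encoder input does not fit is skipped for this step instead of blocking everything behind it. A standalone sketch of that loop shape follows (toy data and names, not vLLM code):

```python
# Toy illustration of skip-and-continue scheduling vs. stopping at the first
# request that does not fit (names and numbers are made up).
def schedule_step(running, token_budget, encoder_budget):
    scheduled = []
    req_index = 0
    while req_index < len(running) and token_budget > 0:
        name, num_tokens, encoder_tokens = running[req_index]
        if encoder_tokens > encoder_budget:
            # `continue` (not `break`): later requests may still fit, at the
            # cost of not being strictly FCFS for this step.
            req_index += 1
            continue
        num_tokens = min(num_tokens, token_budget)
        token_budget -= num_tokens
        encoder_budget -= encoder_tokens
        scheduled.append((name, num_tokens))
        req_index += 1
    return scheduled

# "B" cannot fit its encoder input, but "C" is still scheduled.
assert schedule_step([("A", 100, 0), ("B", 100, 600), ("C", 100, 0)],
                     token_budget=512,
                     encoder_budget=400) == [("A", 100), ("C", 100)]
```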
@@ -172,8 +170,6 @@
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1
-            has_partial_request = (request.num_computed_tokens + num_new_tokens
-                                   < request.num_tokens)
 
             # Encoder-related.
             if encoder_inputs_to_schedule:
@@ -186,13 +182,9 @@
 
         # Next, schedule the WAITING requests.
         if not preempted_reqs:
-            while self.waiting:
-                if has_partial_request:
-                    break
+            while self.waiting and token_budget > 0:
                 if len(self.running) == self.max_num_running_reqs:
                     break
-                if token_budget == 0:
-                    break
 
                 request = self.waiting[0]
                 # Get already-cached tokens.
@@ -249,8 +241,6 @@
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
                 request.num_computed_tokens = num_computed_tokens
-                has_partial_request = (num_computed_tokens + num_new_tokens
-                                       < request.num_tokens)
 
                 # Encoder-related.
                 if encoder_inputs_to_schedule:
@@ -266,8 +256,11 @@
         assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
         assert token_budget >= 0
         assert len(self.running) <= self.max_num_running_reqs
+        # Since some requests in the RUNNING queue may not be scheduled in
+        # this step, the total number of scheduled requests can be smaller than
+        # len(self.running).
         assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
-                len(scheduled_running_reqs) == len(self.running))
+                len(scheduled_running_reqs) <= len(self.running))
 
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
@@ -286,25 +279,28 @@
             for req in scheduled_new_reqs
         ]
         resumed_reqs_data = [
-            ResumedRequestData.from_request(
-                req, req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens) for req in scheduled_resumed_reqs
+            self._make_cached_request_data(
+                req,
+                req_to_new_block_ids[req.request_id],
+                req.num_computed_tokens,
+                resumed_from_preemption=True,
+            ) for req in scheduled_resumed_reqs
         ]
         running_reqs_data = [
-            self._make_running_request_data(
-                req, req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens) for req in scheduled_running_reqs
+            self._make_cached_request_data(
+                req,
+                req_to_new_block_ids[req.request_id],
+                req.num_computed_tokens,
+                resumed_from_preemption=False,
+            ) for req in scheduled_running_reqs
         ]
         preempted_req_ids = {req.request_id for req in preempted_reqs}
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
-            scheduled_resumed_reqs=resumed_reqs_data,
-            scheduled_running_reqs=running_reqs_data,
+            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
             scheduled_encoder_inputs=scheduled_encoder_inputs,
             num_common_prefix_blocks=num_common_prefix_blocks,
             preempted_req_ids=preempted_req_ids,
             # finished_req_ids is an existing state in the scheduler,
             # instead of being newly scheduled in this step.
             # It contains the request IDs that are finished in between
@@ -316,22 +312,26 @@
         self.finished_req_ids = set()
         return scheduler_output
 
-    def _make_running_request_data(
+    def _make_cached_request_data(
         self,
         request: Request,
         new_block_ids: List[int],
         num_computed_tokens: int,
-    ) -> "RunningRequestData":
-        # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
+        resumed_from_preemption: bool,
+    ) -> "CachedRequestData":
+        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        if request.request_id in self.running_reqs_data:
-            req_data = self.running_reqs_data[request.request_id]
+        if request.request_id in self._cached_reqs_data:
+            req_data = self._cached_reqs_data[request.request_id]
+            req_data.resumed_from_preemption = resumed_from_preemption
             req_data.new_block_ids = new_block_ids
             req_data.num_computed_tokens = num_computed_tokens
         else:
-            req_data = RunningRequestData.from_request(request, new_block_ids,
-                                                       num_computed_tokens)
-            self.running_reqs_data[request.request_id] = req_data
+            req_data = CachedRequestData.from_request(request,
+                                                      resumed_from_preemption,
+                                                      new_block_ids,
+                                                      num_computed_tokens)
+            self._cached_reqs_data[request.request_id] = req_data
         return req_data
 
     def _try_schedule_encoder_inputs(
@@ -420,7 +420,13 @@
         # expensive operations inside the loop.
         for request in self.running:
             req_id = request.request_id
-            request.num_computed_tokens += num_scheduled_tokens[req_id]
+            num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
+            if num_tokens_scheduled == 0:
+                # The request was not scheduled in this step.
+                new_running.append(request)
+                continue
+
+            request.num_computed_tokens += num_tokens_scheduled
             # When the request's num_computed_tokens catches up its num_tokens,
             # the request generates output tokens. Otherwise, we ignore the
             # sampler output for the request.
@@ -529,7 +535,7 @@
         assert request.is_finished()
         self.kv_cache_manager.free(request)
         self.encoder_cache_manager.free(request)
-        self.running_reqs_data.pop(request.request_id, None)
+        self._cached_reqs_data.pop(request.request_id, None)
         del self.requests[request.request_id]
         self.finished_req_ids.add(request.request_id)
 
@@ -584,30 +590,13 @@ class NewRequestData:
 
 
-@dataclass
-class ResumedRequestData:
-
-    req_id: str
-    block_ids: List[int]
-    num_computed_tokens: int
-
-    @classmethod
-    def from_request(
-        cls,
-        request: Request,
-        block_ids: List[int],
-        num_computed_tokens: int,
-    ) -> "ResumedRequestData":
-        return cls(
-            req_id=request.request_id,
-            block_ids=block_ids,
-            num_computed_tokens=num_computed_tokens,
-        )
-
-
 @dataclass
-class RunningRequestData:
+class CachedRequestData:
 
     req_id: str
+    # If resumed_from_preemption is False, new_block_ids will be appended to
+    # the request's block IDs. If True, new_block_ids will be used as the
+    # request's block IDs instead of appending to the existing block IDs.
+    resumed_from_preemption: bool
     new_block_ids: List[int]
     num_computed_tokens: int
 
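The `resumed_from_preemption` comment above describes a consumer-side contract: extend the block table for a request that kept running, replace it for one resumed after preemption. A minimal sketch of that contract as a free-standing helper (the model-runner hunk further down does the same thing inline):

```python
from typing import List

def apply_new_block_ids(block_ids: List[int], new_block_ids: List[int],
                        resumed_from_preemption: bool) -> List[int]:
    # Illustrative helper only; not part of the diff.
    if resumed_from_preemption:
        # The old blocks were freed when the request was preempted, so the
        # new IDs replace the request's block table.
        return list(new_block_ids)
    # The request kept its blocks; the new IDs are appended.
    return block_ids + new_block_ids

assert apply_new_block_ids([1, 2], [3], resumed_from_preemption=False) == [1, 2, 3]
assert apply_new_block_ids([1, 2], [7, 8], resumed_from_preemption=True) == [7, 8]
```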
@@ -615,11 +604,13 @@ class RunningRequestData:
     def from_request(
         cls,
         request: Request,
+        resumed_from_preemption: bool,
         new_block_ids: List[int],
         num_computed_tokens: int,
-    ) -> "RunningRequestData":
+    ) -> "CachedRequestData":
         return cls(
             req_id=request.request_id,
+            resumed_from_preemption=resumed_from_preemption,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
         )
@@ -629,14 +620,12 @@ class RunningRequestData:
 class SchedulerOutput:
 
     scheduled_new_reqs: List[NewRequestData]
-    scheduled_resumed_reqs: List[ResumedRequestData]
-    scheduled_running_reqs: List[RunningRequestData]
+    scheduled_cached_reqs: List[CachedRequestData]
 
     num_scheduled_tokens: Dict[str, int]
     total_num_scheduled_tokens: int
     scheduled_encoder_inputs: Dict[str, List[int]]
     num_common_prefix_blocks: int
 
     preempted_req_ids: Set[str]
     finished_req_ids: Set[str]
     free_encoder_input_ids: List[Tuple[str, int]]
@@ -46,6 +46,8 @@ class BlockTable:
         start: int,
         block_ids: List[int],
     ) -> None:
+        if not block_ids:
+            return
         num_blocks = len(block_ids)
         self.block_table_np[row_idx, start:start + num_blocks] = block_ids
         self.num_blocks_per_row[row_idx] = start + num_blocks
@@ -205,12 +205,32 @@ class GPUModelRunner:
                                         pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
 
-    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
-        # Remove stopped requests from the cached states.
-        # Keep the states of the preempted requests.
+    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
+        """Update the cached states and the persistent batch with the scheduler
+        output.
+
+        The updated states are used by the `_prepare_inputs` function to create
+        the input GPU tensors for the model.
+
+        Returns:
+            True if there is a new/resumed/paused/finished request in the batch.
+            If False, we can skip copying SamplingMetadata to the GPU.
+        """
+        # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
             self.encoder_cache.pop(req_id, None)
+        # Remove the finished requests from the persistent batch.
+        # NOTE(woosuk): There could be an edge case where finished_req_ids and
+        # scheduled_req_ids overlap. This happens when a request is aborted and
+        # then resubmitted with the same ID. In this case, we treat them as two
+        # distinct requests - clearing the cached states for the first request
+        # and handling the second as a new request.
+        removed_req_indices: List[int] = []
+        for req_id in scheduler_output.finished_req_ids:
+            req_index = self.input_batch.remove_request(req_id)
+            if req_index is not None:
+                removed_req_indices.append(req_index)
 
         # Free the cached encoder outputs.
         for req_id, input_id in scheduler_output.free_encoder_input_ids:
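The docstring's return contract is consumed later in this diff: `execute_model` passes the boolean into `_prepare_sampling`, which turns it into `skip_copy`. A condensed sketch of just that contract (the helper name below is made up; the real code is in the hunks further down):

```python
# Condensed sketch of the batch_changed -> skip_copy contract described in the
# docstring above (illustrative only).
def make_skip_copy(batch_changed: bool) -> bool:
    # If nothing was added, removed, resumed, or paused, the sampling tensors
    # already on the GPU are still valid and need not be copied again.
    return not batch_changed

assert make_skip_copy(batch_changed=False) is True   # steady-state decode steps
assert make_skip_copy(batch_changed=True) is False   # batch membership changed
```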
@@ -220,36 +240,22 @@ class GPUModelRunner:
             if not encoder_outputs:
                 self.encoder_cache.pop(req_id, None)
 
-        # Remove the requests from the persistent batch.
-        stopped_req_ids = set().union(
-            scheduler_output.preempted_req_ids,
-            scheduler_output.finished_req_ids,
-        )
-        removed_req_indices: List[int] = []
-        for req_id in stopped_req_ids:
+        # Remove the unscheduled requests from the persistent batch.
+        # NOTE(woosuk): The unscheduled requests are either preempted requests
+        # or running requests that are not scheduled in this step. We remove
+        # them from the persistent batch but keep their cached states since
+        # they will be scheduled again sometime in the future.
+        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
+        cached_req_ids = self.input_batch.req_id_to_index.keys()
+        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+        # NOTE(woosuk): The persistent batch optimization assumes that
+        # consecutive batches contain mostly the same requests. If batches
+        # have low request overlap (e.g., alternating between two distinct
+        # sets of requests), this optimization becomes very inefficient.
+        for req_id in unscheduled_req_ids:
             req_index = self.input_batch.remove_request(req_id)
-            if req_index is not None:
-                removed_req_indices.append(req_index)
-
-        # Update the states of the running requests.
-        for req_data in scheduler_output.scheduled_running_reqs:
-            req_id = req_data.req_id
-            req_state = self.requests[req_id]
-            req_index = self.input_batch.req_id_to_index[req_id]
-
-            # Update the num_computed_tokens.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
-            self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-
-            # Update the block table.
-            num_new_blocks = len(req_data.new_block_ids)
-            if num_new_blocks == 0:
-                continue
-            start_index = len(req_state.block_ids)
-            req_state.block_ids.extend(req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            assert req_index is not None
+            removed_req_indices.append(req_index)
 
         req_ids_to_add: List[str] = []
         # Add new requests to the cached states.
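The `cached_req_ids - scheduled_req_ids` expression above relies on `dict.keys()` returning a set-like view. A small standalone example, using IDs and counts from the test scenario earlier in this commit:

```python
# dict.keys() views support set operations, which is what computes the
# requests to drop from the persistent batch while keeping their cached state.
req_id_to_index = {"0": 0, "1": 1, "2": 2}     # persistent batch contents
num_scheduled_tokens = {"0": 1, "1": 700}       # scheduled in this step

unscheduled_req_ids = req_id_to_index.keys() - num_scheduled_tokens.keys()
assert unscheduled_req_ids == {"2"}  # removed from the batch, state kept
```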
@@ -305,14 +311,36 @@ class GPUModelRunner:
 
             req_ids_to_add.append(req_id)
 
-        # Update the cached states of the resumed requests.
-        for res_req_data in scheduler_output.scheduled_resumed_reqs:
-            req_id = res_req_data.req_id
+        # Update the states of the running/resumed requests.
+        for req_data in scheduler_output.scheduled_cached_reqs:
+            req_id = req_data.req_id
             req_state = self.requests[req_id]
 
-            req_state.block_ids = res_req_data.block_ids
-            req_state.num_computed_tokens = res_req_data.num_computed_tokens
-            req_ids_to_add.append(req_id)
+            # Update the cached states.
+            req_state.num_computed_tokens = req_data.num_computed_tokens
+            if not req_data.resumed_from_preemption:
+                # Append the new blocks to the existing block IDs.
+                req_state.block_ids.extend(req_data.new_block_ids)
+            else:
+                # The request is resumed from preemption.
+                # Replace the existing block IDs with the new ones.
+                req_state.block_ids = req_data.new_block_ids
+
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+            if req_index is None:
+                # The request is not in the persistent batch.
+                # The request was either preempted and resumed later, or was not
+                # scheduled in the previous step and needs to be added again.
+                req_ids_to_add.append(req_id)
+                continue
+
+            # Update the persistent batch.
+            self.input_batch.num_computed_tokens_cpu[req_index] = (
+                req_data.num_computed_tokens)
+            start_index = len(req_state.block_ids) - len(
+                req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_index, start_index,
+                                                    req_data.new_block_ids)
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
@@ -330,6 +358,7 @@ class GPUModelRunner:
         # Condense the batched states if there are empty indices.
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
+        return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
 
     def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -536,10 +565,10 @@ class GPUModelRunner:
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
         )
-        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
-        # request in the batch. While we should not sample any token from this
-        # partial request, we do so for simplicity. We will ignore the sampled
-        # token from the partial request.
+        # NOTE(woosuk): Due to chunked prefills, the batch may contain partial
+        # requests. While we should not sample any token from these partial
+        # requests, we do so for simplicity. We will ignore the sampled
+        # tokens from the partial requests.
         # TODO: Support prompt logprobs.
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices
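For the `logits_indices = query_start_loc[1:] - 1` line above: `query_start_loc` is the prefix sum of per-request query lengths with a leading zero, so subtracting one from each subsequent entry indexes the last scheduled token of every request, partial or not. A small numeric sketch (NumPy used only for illustration; values match the test scenario):

```python
import numpy as np

# Per-request token counts for one step: one decode step and two prefill chunks.
num_scheduled = np.array([1, 700, 100])
# Exclusive prefix sum of the query lengths: [0, 1, 701, 801].
query_start_loc = np.concatenate(([0], np.cumsum(num_scheduled)))

# Flattened index of each request's last scheduled token: [0, 700, 800].
logits_indices = query_start_loc[1:] - 1
assert logits_indices.tolist() == [0, 700, 800]
# Tokens sampled at partial-prefill positions are discarded downstream.
```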
@@ -601,22 +630,15 @@ class GPUModelRunner:
 
     def _prepare_sampling(
         self,
-        scheduler_output: "SchedulerOutput",
+        batch_changed: bool,
     ) -> SamplingMetadata:
-        skip_copy = True
-        if (scheduler_output.finished_req_ids
-                or scheduler_output.preempted_req_ids):
-            skip_copy = False
-        if (scheduler_output.scheduled_new_reqs
-                or scheduler_output.scheduled_resumed_reqs):
-            skip_copy = False
         # Create the sampling metadata.
         req_id_output_token_ids: Dict[str, List[int]] = \
             {req_id: req.output_token_ids \
                 for req_id, req in self.requests.items()}
 
         sampling_metadata = self.input_batch.make_sampling_metadata(
-            req_id_output_token_ids, skip_copy)
+            req_id_output_token_ids, skip_copy=not batch_changed)
         return sampling_metadata
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@@ -715,7 +737,7 @@ class GPUModelRunner:
         self,
         scheduler_output: "SchedulerOutput",
     ) -> ModelRunnerOutput:
-        self._update_states(scheduler_output)
+        batch_changed = self._update_states(scheduler_output)
 
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -778,7 +800,7 @@ class GPUModelRunner:
         logits = self.model.compute_logits(hidden_states, None)
 
         # Sample the next token and get logprobs if needed.
-        sampling_metadata = self._prepare_sampling(scheduler_output)
+        sampling_metadata = self._prepare_sampling(batch_changed)
         sampler_output = self.model.sample(
             logits=logits,
             sampling_metadata=sampling_metadata,