[V1] Remove scheduling constraint on partial requests (#12674)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent d1ca7df84d
commit 18a88fcccc

Changed files include tests/v1/core/test_scheduler.py (new file, 214 lines added).
tests/v1/core/test_scheduler.py (new file)
@@ -0,0 +1,214 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus


def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
) -> Scheduler:
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_num_batched_tokens,
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
    )
    cache_config.num_gpu_blocks = 10000
    return Scheduler(scheduler_config,
                     model_config,
                     cache_config,
                     lora_config=None)


def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[List[PlaceholderRange]] = None,
):
    sampling_params = SamplingParams()
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
        else:
            mm_position = None
            mm_inputs = None
        request = Request(
            request_id=f"{i}",
            prompt=None,
            prompt_token_ids=[i] * num_tokens,
            sampling_params=sampling_params,
            multi_modal_inputs=mm_inputs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
            eos_token_id=None,
            arrival_time=0,
        )
        requests.append(request)
    return requests


def test_add_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)

    for i, request in enumerate(requests):
        scheduler.add_request(request)
        assert request.request_id in scheduler.requests
        assert len(scheduler.waiting) == i + 1


def test_finish_request():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_ABORTED)
        assert request.request_id not in scheduler.requests
        assert len(scheduler.waiting) == 9 - i


def test_get_num_unfinished_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_STOPPED)
        assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1


def test_schedule():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    # Test initial scheduling
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert len(output.scheduled_cached_reqs) == 0
    assert len(output.finished_req_ids) == 0
    # Verify all requests are scheduled.
    for req_id, num_tokens in output.num_scheduled_tokens.items():
        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)

    # Verify requests moved from waiting to running
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == len(requests)
    for i, request in enumerate(requests):
        assert scheduler.running[i] == request


def test_schedule_multimodal_requests():
    scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf")
    mm_positions = [[PlaceholderRange(offset=i, length=100)]
                    for i in range(10)]
    requests = create_requests(
        num_requests=10,
        num_tokens=200,
        mm_positions=mm_positions,
    )
    for request in requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert len(output.scheduled_cached_reqs) == 0
    assert len(output.finished_req_ids) == 0
    for req_id, num_tokens in output.num_scheduled_tokens.items():
        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
    assert len(output.scheduled_encoder_inputs) == 10
    for req_id, encoder_input in output.scheduled_encoder_inputs.items():
        assert len(encoder_input) == 1


def test_schedule_partial_requests():
    """Test scheduling behavior with partial requests.

    This test verifies that:
    1. The scheduler can handle multiple partial requests in a single step when
       constrained by encoder budget.
    2. A request in RUNNING state may be unscheduled in subsequent steps if
       there is insufficient encoder budget.
    """
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        max_num_batched_tokens=1024,
    )
    mm_positions = [[PlaceholderRange(offset=100, length=600)]
                    for _ in range(3)]
    requests = create_requests(
        num_requests=3,
        num_tokens=800,
        mm_positions=mm_positions,
    )
    for request in requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 3
    assert len(output.scheduled_cached_reqs) == 0
    assert len(output.finished_req_ids) == 0

    assert scheduler.max_num_encoder_input_tokens == 1024
    # The first request is scheduled fully.
    assert output.num_scheduled_tokens[requests[0].request_id] == 800
    # The second request is scheduled partially.
    # The <img> tokens are not scheduled because of the encoder budget.
    assert output.num_scheduled_tokens[requests[1].request_id] == 100
    # The third request is also scheduled partially.
    # The <img> tokens are not scheduled because of the encoder budget.
    assert output.num_scheduled_tokens[requests[2].request_id] == 100
    req_to_index = {
        request.request_id: i
        for i, request in enumerate(requests)
    }
    model_runner_output = ModelRunnerOutput(
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[0] * len(requests),
        logprob_token_ids_cpu=None,
        logprobs_cpu=None,
    )
    scheduler.update_from_output(output, model_runner_output)

    # Schedule the next step.
    # Only the first and second requests are scheduled.
    # The third request is in the RUNNING state but not scheduled in this step
    # because of the encoder budget.
    output = scheduler.schedule()
    assert len(scheduler.running) == 3
    assert len(output.scheduled_new_reqs) == 0
    assert len(output.scheduled_cached_reqs) == 2
    assert len(output.finished_req_ids) == 0
    assert output.num_scheduled_tokens[requests[0].request_id] == 1
    assert output.num_scheduled_tokens[requests[1].request_id] == 700
    assert requests[2].request_id not in output.num_scheduled_tokens
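A natural continuation of `test_schedule_partial_requests`, sketched below, would report the second step back to the scheduler and then schedule a third step; at that point the images of requests 0 and 1 have already been computed and their encoder inputs freed, so the previously skipped request 2 should finally get its encoder input scheduled. This continuation is not part of the commit, and the expected token counts in the asserts are my reading of the budget arithmetic above, not values taken from the diff.

```python
    # Hypothetical third step, continuing inside test_schedule_partial_requests.
    model_runner_output = ModelRunnerOutput(
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[0] * len(requests),
        logprob_token_ids_cpu=None,
        logprobs_cpu=None,
    )
    scheduler.update_from_output(output, model_runner_output)

    output = scheduler.schedule()
    # Requests 0 and 1 should now be decoding (1 token each), and request 2
    # should catch up with its remaining 700 prompt tokens because the
    # per-step encoder budget (1024) again covers its 600-token image.
    assert output.num_scheduled_tokens[requests[0].request_id] == 1
    assert output.num_scheduled_tokens[requests[1].request_id] == 1
    assert output.num_scheduled_tokens[requests[2].request_id] == 700
```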
@@ -67,10 +67,10 @@ class Scheduler:
         # This is flushed at the end of each scheduling step.
         self.finished_req_ids: Set[str] = set()
 
-        # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
+        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        # Request id -> RunningRequestData
-        self.running_reqs_data: Dict[str, RunningRequestData] = {}
+        # Request id -> CachedRequestData
+        self._cached_reqs_data: Dict[str, CachedRequestData] = {}
 
         # Encoder-related.
         # Calculate encoder cache size if applicable
@@ -115,17 +115,8 @@ class Scheduler:
         encoder_budget = self.max_num_encoder_input_tokens
 
         # First, schedule the RUNNING requests.
-        # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be
-        # in the "partial" state, where the request has some tokens computed
-        # but not all. The constraint is due to the persistent batch in the
-        # V1 model runner.
-        # TODO(woosuk): Remove this constraint after refactoring model runner.
-        has_partial_request = False
         req_index = 0
-        while req_index < len(self.running):
-            # Only the last request in the RUNNING queue can be "partial".
-            assert not has_partial_request
-            assert token_budget > 0
+        while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
             num_new_tokens = request.num_tokens - request.num_computed_tokens
             num_new_tokens = min(num_new_tokens, token_budget)
@@ -137,7 +128,14 @@ class Scheduler:
                                                   request.num_computed_tokens,
                                                   num_new_tokens,
                                                   encoder_budget))
-            assert num_new_tokens > 0
+            if num_new_tokens == 0:
+                # The request cannot be scheduled because the encoder budget
+                # or the encoder cache is exhausted.
+                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
+                # we do not strictly follow the FCFS scheduling policy and
+                # allow the lower-priority requests to be scheduled.
+                req_index += 1
+                continue
 
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
@@ -172,8 +170,6 @@ class Scheduler:
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1
-            has_partial_request = (request.num_computed_tokens + num_new_tokens
-                                   < request.num_tokens)
 
             # Encoder-related.
             if encoder_inputs_to_schedule:
@@ -186,13 +182,9 @@ class Scheduler:
 
         # Next, schedule the WAITING requests.
         if not preempted_reqs:
-            while self.waiting:
-                if has_partial_request:
-                    break
+            while self.waiting and token_budget > 0:
                 if len(self.running) == self.max_num_running_reqs:
                     break
-                if token_budget == 0:
-                    break
 
                 request = self.waiting[0]
                 # Get already-cached tokens.
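For intuition, here is a small standalone sketch (plain Python, not vLLM code) of the budget arithmetic behind the 800/100/100 split asserted in the new test: each prompt is 800 tokens with a 600-token image placeholder at offset 100, and both the per-step token budget and the encoder budget are 1024. When the encoder budget cannot cover an image, only the tokens before the placeholder are schedulable, but the requests behind it are still considered.

```python
token_budget = 1024
encoder_budget = 1024
prompt_len, image_offset, image_len = 800, 100, 600

scheduled = []
for _ in range(3):  # three identical multimodal requests
    num_new = min(prompt_len, token_budget)
    if encoder_budget >= image_len:
        # The image's encoder pass fits: the whole prompt can be scheduled.
        encoder_budget -= image_len
    else:
        # Encoder budget exhausted: schedule only the tokens before the
        # image placeholder and move on to the next request.
        num_new = min(num_new, image_offset)
    scheduled.append(num_new)
    token_budget -= num_new

print(scheduled)  # [800, 100, 100]
```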
@@ -249,8 +241,6 @@ class Scheduler:
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
                 request.num_computed_tokens = num_computed_tokens
-                has_partial_request = (num_computed_tokens + num_new_tokens
-                                       < request.num_tokens)
 
                 # Encoder-related.
                 if encoder_inputs_to_schedule:
@@ -266,8 +256,11 @@ class Scheduler:
         assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
         assert token_budget >= 0
         assert len(self.running) <= self.max_num_running_reqs
+        # Since some requests in the RUNNING queue may not be scheduled in
+        # this step, the total number of scheduled requests can be smaller than
+        # len(self.running).
         assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
-                len(scheduled_running_reqs) == len(self.running))
+                len(scheduled_running_reqs) <= len(self.running))
 
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
@@ -286,25 +279,28 @@ class Scheduler:
             for req in scheduled_new_reqs
         ]
         resumed_reqs_data = [
-            ResumedRequestData.from_request(
-                req, req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens) for req in scheduled_resumed_reqs
+            self._make_cached_request_data(
+                req,
+                req_to_new_block_ids[req.request_id],
+                req.num_computed_tokens,
+                resumed_from_preemption=True,
+            ) for req in scheduled_resumed_reqs
         ]
         running_reqs_data = [
-            self._make_running_request_data(
-                req, req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens) for req in scheduled_running_reqs
+            self._make_cached_request_data(
+                req,
+                req_to_new_block_ids[req.request_id],
+                req.num_computed_tokens,
+                resumed_from_preemption=False,
+            ) for req in scheduled_running_reqs
         ]
-        preempted_req_ids = {req.request_id for req in preempted_reqs}
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
-            scheduled_resumed_reqs=resumed_reqs_data,
-            scheduled_running_reqs=running_reqs_data,
+            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
             scheduled_encoder_inputs=scheduled_encoder_inputs,
             num_common_prefix_blocks=num_common_prefix_blocks,
-            preempted_req_ids=preempted_req_ids,
             # finished_req_ids is an existing state in the scheduler,
             # instead of being newly scheduled in this step.
             # It contains the request IDs that are finished in between
@@ -316,22 +312,26 @@ class Scheduler:
         self.finished_req_ids = set()
         return scheduler_output
 
-    def _make_running_request_data(
+    def _make_cached_request_data(
         self,
         request: Request,
         new_block_ids: List[int],
         num_computed_tokens: int,
-    ) -> "RunningRequestData":
-        # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
+        resumed_from_preemption: bool,
+    ) -> "CachedRequestData":
+        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        if request.request_id in self.running_reqs_data:
-            req_data = self.running_reqs_data[request.request_id]
+        if request.request_id in self._cached_reqs_data:
+            req_data = self._cached_reqs_data[request.request_id]
+            req_data.resumed_from_preemption = resumed_from_preemption
             req_data.new_block_ids = new_block_ids
             req_data.num_computed_tokens = num_computed_tokens
         else:
-            req_data = RunningRequestData.from_request(request, new_block_ids,
-                                                       num_computed_tokens)
-            self.running_reqs_data[request.request_id] = req_data
+            req_data = CachedRequestData.from_request(request,
                                                      resumed_from_preemption,
                                                      new_block_ids,
                                                      num_computed_tokens)
+            self._cached_reqs_data[request.request_id] = req_data
         return req_data
 
     def _try_schedule_encoder_inputs(
@@ -420,7 +420,13 @@ class Scheduler:
         # expensive operations inside the loop.
         for request in self.running:
             req_id = request.request_id
-            request.num_computed_tokens += num_scheduled_tokens[req_id]
+            num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
+            if num_tokens_scheduled == 0:
+                # The request was not scheduled in this step.
+                new_running.append(request)
+                continue
+
+            request.num_computed_tokens += num_tokens_scheduled
             # When the request's num_computed_tokens catches up its num_tokens,
             # the request generates output tokens. Otherwise, we ignore the
             # sampler output for the request.
@@ -529,7 +535,7 @@ class Scheduler:
         assert request.is_finished()
         self.kv_cache_manager.free(request)
         self.encoder_cache_manager.free(request)
-        self.running_reqs_data.pop(request.request_id, None)
+        self._cached_reqs_data.pop(request.request_id, None)
         del self.requests[request.request_id]
         self.finished_req_ids.add(request.request_id)
 
@@ -584,30 +590,13 @@ class NewRequestData:
 
 
 @dataclass
-class ResumedRequestData:
+class CachedRequestData:
 
-    req_id: str
-    block_ids: List[int]
-    num_computed_tokens: int
-
-    @classmethod
-    def from_request(
-        cls,
-        request: Request,
-        block_ids: List[int],
-        num_computed_tokens: int,
-    ) -> "ResumedRequestData":
-        return cls(
-            req_id=request.request_id,
-            block_ids=block_ids,
-            num_computed_tokens=num_computed_tokens,
-        )
-
-
-@dataclass
-class RunningRequestData:
-
     req_id: str
+    # If resumed_from_preemption is False, new_block_ids will be appended to
+    # the request's block IDs. If True, new_block_ids will be used as the
+    # request's block IDs instead of appending to the existing block IDs.
+    resumed_from_preemption: bool
     new_block_ids: List[int]
     num_computed_tokens: int
 
@@ -615,11 +604,13 @@ class RunningRequestData:
     def from_request(
         cls,
         request: Request,
+        resumed_from_preemption: bool,
         new_block_ids: List[int],
         num_computed_tokens: int,
-    ) -> "RunningRequestData":
+    ) -> "CachedRequestData":
         return cls(
             req_id=request.request_id,
+            resumed_from_preemption=resumed_from_preemption,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
         )
@@ -629,14 +620,12 @@ class RunningRequestData:
 class SchedulerOutput:
 
     scheduled_new_reqs: List[NewRequestData]
-    scheduled_resumed_reqs: List[ResumedRequestData]
-    scheduled_running_reqs: List[RunningRequestData]
+    scheduled_cached_reqs: List[CachedRequestData]
 
     num_scheduled_tokens: Dict[str, int]
    total_num_scheduled_tokens: int
     scheduled_encoder_inputs: Dict[str, List[int]]
     num_common_prefix_blocks: int
 
-    preempted_req_ids: Set[str]
     finished_req_ids: Set[str]
     free_encoder_input_ids: List[Tuple[str, int]]
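With `preempted_req_ids` gone and the resumed/running lists merged into `scheduled_cached_reqs`, the `resumed_from_preemption` flag is what tells a consumer how to apply `new_block_ids`. A minimal sketch of that rule, assuming the consumer keeps a per-request list of block IDs (the GPU model runner hunks further down apply the same logic):

```python
from typing import List


def apply_new_block_ids(block_ids: List[int], new_block_ids: List[int],
                        resumed_from_preemption: bool) -> List[int]:
    if resumed_from_preemption:
        # The request's blocks were freed when it was preempted, so the newly
        # allocated blocks replace the old list entirely.
        return list(new_block_ids)
    # A request that stayed in the running batch simply grows its block list.
    return block_ids + new_block_ids
```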
@@ -46,6 +46,8 @@ class BlockTable:
         start: int,
         block_ids: List[int],
     ) -> None:
+        if not block_ids:
+            return
         num_blocks = len(block_ids)
         self.block_table_np[row_idx, start:start + num_blocks] = block_ids
         self.num_blocks_per_row[row_idx] = start + num_blocks
@@ -205,12 +205,32 @@ class GPUModelRunner:
                                          pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
 
-    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
-        # Remove stopped requests from the cached states.
-        # Keep the states of the preempted requests.
+    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
+        """Update the cached states and the persistent batch with the scheduler
+        output.
+
+        The updated states are used by the `_prepare_inputs` function to create
+        the input GPU tensors for the model.
+
+        Returns:
+            True if there is a new/resumed/paused/finished request in the batch.
+            If False, we can skip copying SamplingMetadata to the GPU.
+        """
+        # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
             self.encoder_cache.pop(req_id, None)
+        # Remove the finished requests from the persistent batch.
+        # NOTE(woosuk): There could be an edge case where finished_req_ids and
+        # scheduled_req_ids overlap. This happens when a request is aborted and
+        # then resubmitted with the same ID. In this case, we treat them as two
+        # distinct requests - clearing the cached states for the first request
+        # and handling the second as a new request.
+        removed_req_indices: List[int] = []
+        for req_id in scheduler_output.finished_req_ids:
+            req_index = self.input_batch.remove_request(req_id)
+            if req_index is not None:
+                removed_req_indices.append(req_index)
 
         # Free the cached encoder outputs.
         for req_id, input_id in scheduler_output.free_encoder_input_ids:
@@ -220,36 +240,22 @@ class GPUModelRunner:
                 if not encoder_outputs:
                     self.encoder_cache.pop(req_id, None)
 
-        # Remove the requests from the persistent batch.
-        stopped_req_ids = set().union(
-            scheduler_output.preempted_req_ids,
-            scheduler_output.finished_req_ids,
-        )
-        removed_req_indices: List[int] = []
-        for req_id in stopped_req_ids:
+        # Remove the unscheduled requests from the persistent batch.
+        # NOTE(woosuk): The unscheduled requests are either preempted requests
+        # or running requests that are not scheduled in this step. We remove
+        # them from the persistent batch but keep their cached states since
+        # they will be scheduled again sometime in the future.
+        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
+        cached_req_ids = self.input_batch.req_id_to_index.keys()
+        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+        # NOTE(woosuk): The persistent batch optimization assumes that
+        # consecutive batches contain mostly the same requests. If batches
+        # have low request overlap (e.g., alternating between two distinct
+        # sets of requests), this optimization becomes very inefficient.
+        for req_id in unscheduled_req_ids:
             req_index = self.input_batch.remove_request(req_id)
-            if req_index is not None:
-                removed_req_indices.append(req_index)
-
-        # Update the states of the running requests.
-        for req_data in scheduler_output.scheduled_running_reqs:
-            req_id = req_data.req_id
-            req_state = self.requests[req_id]
-            req_index = self.input_batch.req_id_to_index[req_id]
-
-            # Update the num_computed_tokens.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
-            self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-
-            # Update the block table.
-            num_new_blocks = len(req_data.new_block_ids)
-            if num_new_blocks == 0:
-                continue
-            start_index = len(req_state.block_ids)
-            req_state.block_ids.extend(req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            assert req_index is not None
+            removed_req_indices.append(req_index)
 
         req_ids_to_add: List[str] = []
         # Add new requests to the cached states.
@@ -305,14 +311,36 @@ class GPUModelRunner:
 
             req_ids_to_add.append(req_id)
 
-        # Update the cached states of the resumed requests.
-        for res_req_data in scheduler_output.scheduled_resumed_reqs:
-            req_id = res_req_data.req_id
+        # Update the states of the running/resumed requests.
+        for req_data in scheduler_output.scheduled_cached_reqs:
+            req_id = req_data.req_id
             req_state = self.requests[req_id]
 
-            req_state.block_ids = res_req_data.block_ids
-            req_state.num_computed_tokens = res_req_data.num_computed_tokens
-            req_ids_to_add.append(req_id)
+            # Update the cached states.
+            req_state.num_computed_tokens = req_data.num_computed_tokens
+            if not req_data.resumed_from_preemption:
+                # Append the new blocks to the existing block IDs.
+                req_state.block_ids.extend(req_data.new_block_ids)
+            else:
+                # The request is resumed from preemption.
+                # Replace the existing block IDs with the new ones.
+                req_state.block_ids = req_data.new_block_ids
+
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+            if req_index is None:
+                # The request is not in the persistent batch.
+                # The request was either preempted and resumed later, or was not
+                # scheduled in the previous step and needs to be added again.
+                req_ids_to_add.append(req_id)
+                continue
+
+            # Update the persistent batch.
+            self.input_batch.num_computed_tokens_cpu[req_index] = (
+                req_data.num_computed_tokens)
+            start_index = len(req_state.block_ids) - len(
+                req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_index, start_index,
+                                                    req_data.new_block_ids)
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
@@ -330,6 +358,7 @@ class GPUModelRunner:
         # Condense the batched states if there are empty indices.
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
+        return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
 
     def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -536,10 +565,10 @@ class GPUModelRunner:
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
         )
-        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
-        # request in the batch. While we should not sample any token from this
-        # partial request, we do so for simplicity. We will ignore the sampled
-        # token from the partial request.
+        # NOTE(woosuk): Due to chunked prefills, the batch may contain partial
+        # requests. While we should not sample any token from these partial
+        # requests, we do so for simplicity. We will ignore the sampled
+        # tokens from the partial requests.
         # TODO: Support prompt logprobs.
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices
@@ -601,22 +630,15 @@ class GPUModelRunner:
 
     def _prepare_sampling(
         self,
-        scheduler_output: "SchedulerOutput",
+        batch_changed: bool,
     ) -> SamplingMetadata:
-        skip_copy = True
-        if (scheduler_output.finished_req_ids
-                or scheduler_output.preempted_req_ids):
-            skip_copy = False
-        if (scheduler_output.scheduled_new_reqs
-                or scheduler_output.scheduled_resumed_reqs):
-            skip_copy = False
         # Create the sampling metadata.
         req_id_output_token_ids: Dict[str, List[int]] = \
             {req_id: req.output_token_ids \
                 for req_id, req in self.requests.items()}
 
         sampling_metadata = self.input_batch.make_sampling_metadata(
-            req_id_output_token_ids, skip_copy)
+            req_id_output_token_ids, skip_copy=not batch_changed)
         return sampling_metadata
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@@ -715,7 +737,7 @@ class GPUModelRunner:
         self,
         scheduler_output: "SchedulerOutput",
     ) -> ModelRunnerOutput:
-        self._update_states(scheduler_output)
+        batch_changed = self._update_states(scheduler_output)
 
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -778,7 +800,7 @@ class GPUModelRunner:
         logits = self.model.compute_logits(hidden_states, None)
 
         # Sample the next token and get logprobs if needed.
-        sampling_metadata = self._prepare_sampling(scheduler_output)
+        sampling_metadata = self._prepare_sampling(batch_changed)
         sampler_output = self.model.sample(
             logits=logits,
             sampling_metadata=sampling_metadata,