[V1][Scheduler] Avoid calling _try_schedule_encoder_inputs for every request (#15778)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-03-30 14:10:42 -07:00 committed by GitHub
parent 70fedd0f79
commit 9b459eca88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 30 deletions

View File

@ -171,19 +171,23 @@ class Scheduler(SchedulerInterface):
assert num_new_tokens > 0 assert num_new_tokens > 0
# Schedule encoder inputs. # Schedule encoder inputs.
encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( if request.has_encoder_inputs:
self._try_schedule_encoder_inputs(request, (encoder_inputs_to_schedule, num_new_tokens,
request.num_computed_tokens, new_encoder_budget) = self._try_schedule_encoder_inputs(
num_new_tokens, request, request.num_computed_tokens, num_new_tokens,
encoder_budget)) encoder_budget)
if num_new_tokens == 0: if num_new_tokens == 0:
# The request cannot be scheduled because the encoder budget # The request cannot be scheduled because the encoder budget
# or the encoder cache is exhausted. # or the encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`, # NOTE(woosuk): By using `continue` instead of `break` here,
# we do not strictly follow the FCFS scheduling policy and # we intentionally relax the strict FCFS scheduling policy
# allow the lower-priority requests to be scheduled. # to allow lower-priority requests to be scheduled when a
req_index += 1 # higher-priority request is blocked by encoder constraints.
continue req_index += 1
continue
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
@ -318,13 +322,17 @@ class Scheduler(SchedulerInterface):
assert num_new_tokens > 0 assert num_new_tokens > 0
# Schedule encoder inputs. # Schedule encoder inputs.
(encoder_inputs_to_schedule, num_new_tokens, if request.has_encoder_inputs:
new_encoder_budget) = self._try_schedule_encoder_inputs( (encoder_inputs_to_schedule, num_new_tokens,
request, num_computed_tokens, num_new_tokens, new_encoder_budget) = self._try_schedule_encoder_inputs(
encoder_budget) request, num_computed_tokens, num_new_tokens,
if num_new_tokens == 0: encoder_budget)
# The request cannot be scheduled. if num_new_tokens == 0:
break # The request cannot be scheduled.
break
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, computed_blocks) request, num_new_tokens, computed_blocks)
@ -506,9 +514,6 @@ class Scheduler(SchedulerInterface):
limitations, the method adjusts `num_new_tokens` to schedule only the limitations, the method adjusts `num_new_tokens` to schedule only the
decoder tokens up to just before the unschedulable encoder input. decoder tokens up to just before the unschedulable encoder input.
""" """
if not request.has_encoder_inputs():
return [], num_new_tokens, encoder_budget
encoder_inputs_to_schedule: list[int] = [] encoder_inputs_to_schedule: list[int] = []
mm_positions = request.mm_positions mm_positions = request.mm_positions
assert mm_positions is not None assert mm_positions is not None

View File

@ -59,6 +59,8 @@ class Request:
self.mm_positions = multi_modal_placeholders or [] self.mm_positions = multi_modal_placeholders or []
self.mm_inputs = multi_modal_inputs or [] self.mm_inputs = multi_modal_inputs or []
self.mm_hashes: list[str] = multi_modal_hashes or [] self.mm_hashes: list[str] = multi_modal_hashes or []
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# Sanity check # Sanity check
assert len(self.mm_inputs) == len(self.mm_positions) assert len(self.mm_inputs) == len(self.mm_positions)
@ -117,13 +119,6 @@ class Request:
def get_finished_reason(self) -> Union[FinishReason, None]: def get_finished_reason(self) -> Union[FinishReason, None]:
return RequestStatus.get_finished_reason(self.status) return RequestStatus.get_finished_reason(self.status)
def has_encoder_inputs(self) -> bool:
return len(self.mm_inputs) > 0
@property
def num_encoder_inputs(self) -> int:
return len(self.mm_positions)
def get_num_encoder_tokens(self, input_id: int) -> int: def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_positions) assert input_id < len(self.mm_positions)
num_tokens = self.mm_positions[input_id]["length"] num_tokens = self.mm_positions[input_id]["length"]