[Core] Scheduling optimization 2 (#4280)
This commit is contained in:
parent
8f2ea22bde
commit
050f285ff6
@ -563,7 +563,8 @@ def test_decode_schedule_preempted():
|
||||
assert len(output.preempted) == 2
|
||||
# Verify budgets are updated.
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 1
|
||||
# NOTE: When enable_chunk is False, num_seqs budget is not updated.
|
||||
# assert budget.num_curr_seqs == 1
|
||||
# Both should be preempted, not swapped.
|
||||
assert output.blocks_to_swap_out == {}
|
||||
# Nothing is copied.
|
||||
|
@ -395,12 +395,12 @@ class Scheduler:
|
||||
# We can have up to 1 running prefill at any given time in running
|
||||
# queue, which means we can guarantee chunk size is at least 1.
|
||||
assert num_running_tokens != 0
|
||||
num_running_seqs = seq_group.get_max_num_running_seqs()
|
||||
|
||||
running_queue.popleft()
|
||||
while not self._can_append_slots(seq_group):
|
||||
budget.subtract_num_batched_tokens(seq_group.request_id,
|
||||
num_running_tokens)
|
||||
num_running_seqs = seq_group.get_max_num_running_seqs()
|
||||
budget.subtract_num_seqs(seq_group.request_id,
|
||||
num_running_seqs)
|
||||
if curr_loras is not None and seq_group.lora_int_id > 0:
|
||||
@ -439,7 +439,13 @@ class Scheduler:
|
||||
token_chunk_size=1))
|
||||
budget.add_num_batched_tokens(seq_group.request_id,
|
||||
num_running_tokens)
|
||||
budget.add_num_seqs(seq_group.request_id, num_running_seqs)
|
||||
# OPTIMIZATION: Note that get_max_num_running_seqs is
|
||||
# expensive. For the default scheduling chase where
|
||||
# enable_chunking is False, num_seqs are updated before running
|
||||
# this method, so we don't have to update it again here.
|
||||
if enable_chunking:
|
||||
num_running_seqs = seq_group.get_max_num_running_seqs()
|
||||
budget.add_num_seqs(seq_group.request_id, num_running_seqs)
|
||||
if curr_loras is not None and seq_group.lora_int_id > 0:
|
||||
curr_loras.add(seq_group.lora_int_id)
|
||||
|
||||
|
@ -508,6 +508,11 @@ class SequenceGroup:
|
||||
return num_uncomputed_tokens
|
||||
|
||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||
# Optimization. We don't need to call get_seqs if we don't need to
|
||||
# filter by states.
|
||||
if status is None:
|
||||
return len(self.seqs_dict)
|
||||
|
||||
return len(self.get_seqs(status))
|
||||
|
||||
def num_unfinished_seqs(self) -> int:
|
||||
|
Loading…
x
Reference in New Issue
Block a user