[Core] Scheduling optimization 2 (#4280)

This commit is contained in:
SangBin Cho 2024-04-23 17:02:11 +09:00 committed by GitHub
parent 8f2ea22bde
commit 050f285ff6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 3 deletions

View File

@ -563,7 +563,8 @@ def test_decode_schedule_preempted():
assert len(output.preempted) == 2
# Verify budgets are updated.
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 1
# NOTE: When enable_chunk is False, num_seqs budget is not updated.
# assert budget.num_curr_seqs == 1
# Both should be preempted, not swapped.
assert output.blocks_to_swap_out == {}
# Nothing is copied.

View File

@ -395,12 +395,12 @@ class Scheduler:
# We can have up to 1 running prefill at any given time in running
# queue, which means we can guarantee chunk size is at least 1.
assert num_running_tokens != 0
num_running_seqs = seq_group.get_max_num_running_seqs()
running_queue.popleft()
while not self._can_append_slots(seq_group):
budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens)
num_running_seqs = seq_group.get_max_num_running_seqs()
budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
@ -439,7 +439,13 @@ class Scheduler:
token_chunk_size=1))
budget.add_num_batched_tokens(seq_group.request_id,
num_running_tokens)
budget.add_num_seqs(seq_group.request_id, num_running_seqs)
# OPTIMIZATION: Note that get_max_num_running_seqs is
# expensive. For the default scheduling chase where
# enable_chunking is False, num_seqs are updated before running
# this method, so we don't have to update it again here.
if enable_chunking:
num_running_seqs = seq_group.get_max_num_running_seqs()
budget.add_num_seqs(seq_group.request_id, num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)

View File

@ -508,6 +508,11 @@ class SequenceGroup:
return num_uncomputed_tokens
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
# Optimization. We don't need to call get_seqs if we don't need to
# filter by states.
if status is None:
return len(self.seqs_dict)
return len(self.get_seqs(status))
def num_unfinished_seqs(self) -> int: