[core][scheduler] simplify and improve scheduler (#6867)

youkaichao 2024-07-31 23:51:09 -07:00 committed by GitHub
parent 3c10591ef2
commit c8a7e93273
4 changed files with 112 additions and 214 deletions
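In short: the per-stage helpers `_schedule_prefills`, `_schedule_running`, and `_schedule_swapped` no longer take an explicit queue (and, for the latter two, a sorting `Policy`), and no longer hand the remaining queue back; they read and mutate `self.waiting`, `self.running`, and `self.swapped` directly and return only their outputs object. Since FCFS was the only registered policy and the queues are already kept in arrival order, `vllm/core/policy.py` is deleted outright. A minimal sketch of the signature change, reconstructed from the hunks below:

```python
# Before: the queue is threaded through and the remainder handed back.
def _schedule_prefills(
    self,
    waiting_queue: deque,
    budget: SchedulingBudget,
    curr_loras: Optional[Set[int]],
    enable_chunking: bool = False,
) -> Tuple[deque, SchedulerPrefillOutputs]: ...

# After: the method works on self.waiting in place and returns only the
# outputs; callers read whatever remains off the scheduler itself.
def _schedule_prefills(
    self,
    budget: SchedulingBudget,
    curr_loras: Optional[Set[int]],
    enable_chunking: bool = False,
) -> SchedulerPrefillOutputs: ...
```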

View File

@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size
"num_gpu_blocks_override": 2 * (16 + 1),
"num_gpu_blocks_override": 2 * (16 + 2),
}
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
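With `block_size = 8` (implied by the comment's 16 = 128/8), the override grows from 2 × (16 + 1) = 34 to 2 × (16 + 2) = 36 GPU blocks, one extra block of headroom per sequence; presumably the reworked in-place queue handling shifts the worst-case allocation slightly.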

View File

@@ -1,13 +1,12 @@
import time
from collections import deque
from typing import Deque, List, Set, Tuple
from typing import List, Set, Tuple
from unittest.mock import MagicMock
import pytest # noqa
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.policy import PolicyFactory
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, SequenceGroup, SequenceStatus
@@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len():
"""
scheduler = initialize_scheduler(max_model_len=30)
_, seq_group = create_dummy_prompt("0", prompt_length=60)
waiting = deque([seq_group])
scheduler.add_seq_group(seq_group)
budget = create_token_budget()
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
output = scheduler._schedule_prefills(budget, None)
remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 1
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0
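Every hunk in this test file follows the same migration, shown condensed once here: the hand-built `deque` disappears, sequence groups are enqueued through `scheduler.add_seq_group`, and the remaining queue is read off the scheduler after the call.

```python
# New test pattern used throughout this file (condensed from the hunk above).
scheduler.add_seq_group(seq_group)               # was: waiting.append(seq_group)
output = scheduler._schedule_prefills(budget, None)
remaining_waiting = scheduler.waiting            # was: returned alongside output
assert len(output.ignored_seq_groups) == 1       # 60-token prompt > max_model_len=30
```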
@@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget():
Test token budget respected.
"""
scheduler = initialize_scheduler()
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(token_budget=0)
for i in range(2):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
scheduler.add_seq_group(seq_group)
# 0 token budget == nothing is scheduled.
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
output = scheduler._schedule_prefills(budget, None)
remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0
@@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget():
# 60 token budget == 1 request scheduled.
budget = create_token_budget(token_budget=60)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
output = scheduler._schedule_prefills(budget, None)
remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 1
assert budget.num_batched_tokens == 60
@@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget():
# Test when current_batched_tokens respected.
scheduler = initialize_scheduler()
waiting = deque()
budget = create_token_budget(token_budget=60)
add_token_budget(budget, 30, 0)
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
# Cannot schedule a prompt that doesn't fit the budget.
waiting.append(seq_group)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
scheduler.add_seq_group(seq_group)
output = scheduler._schedule_prefills(budget, None)
remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 30
@@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget():
assert len(remaining_waiting) == 1
budget = create_token_budget(token_budget=90)
add_token_budget(budget, 30, 0)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
output = scheduler._schedule_prefills(budget, None)
remaining_waiting = scheduler.waiting
assert len(output.seq_groups) == 1
assert budget.num_batched_tokens == 90
assert budget.num_curr_seqs == 1
@@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs():
Test max seq respected.
"""
scheduler = initialize_scheduler()
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(max_num_seqs=2)
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
scheduler.add_seq_group(seq_group)
output = scheduler._schedule_prefills(budget, None)
remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 2
assert budget.num_batched_tokens == 120
@@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs():
assert len(remaining_waiting) == 1
# Verify curr_num_seqs respected.
waiting = deque()
scheduler.waiting = deque()
budget = create_token_budget(max_num_seqs=2)
add_token_budget(budget, 0, 2)
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
scheduler.add_seq_group(seq_group)
output = scheduler._schedule_prefills(budget, None)
remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0
@@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora():
"""
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config)
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(token_budget=120)
curr_loras: Set[int] = set()
for i in range(2):
@@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora():
lora_name=str(i),
lora_int_id=i + 1,
lora_path="abc"))
waiting.append(seq_group)
scheduler.add_seq_group(seq_group)
# Add two more requests to verify lora is prioritized.
# 0: Lora, 1: Lora, 2: regular, 3: regular
# In the first iteration, index 0, 2 is scheduled.
@@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora():
# prioritized. Verify that.
for i in range(2, 4):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
scheduler.add_seq_group(seq_group)
# Schedule 2 requests (0 and 2)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, curr_loras)
output = scheduler._schedule_prefills(budget, curr_loras)
remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 2
assert budget.num_batched_tokens == 120
@@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora():
# Reset curr_loras so that it can be scheduled.
curr_loras = set()
budget = create_token_budget(token_budget=60)
remaining_waiting, output = scheduler._schedule_prefills(
remaining_waiting, budget, curr_loras)
output = scheduler._schedule_prefills(budget, curr_loras)
remaining_waiting = scheduler.waiting
assert len(output.seq_groups) == 1
assert output.seq_groups[0].seq_group.request_id == "1"
assert len(remaining_waiting) == 1
@@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity():
Test sequence cannot be scheduled due to block manager has no capacity.
"""
scheduler = initialize_scheduler()
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget()
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
scheduler.add_seq_group(seq_group)
scheduler.block_manager.can_allocate = MagicMock()
scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER
remainig_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
output = scheduler._schedule_prefills(budget, None)
remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 0
assert len(remainig_waiting) == 3
assert len(remaining_waiting) == 3
scheduler = initialize_scheduler()
waiting = deque()
budget = create_token_budget()
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
scheduler.add_seq_group(seq_group)
scheduler.block_manager.can_allocate = MagicMock()
scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
output = scheduler._schedule_prefills(budget, None)
remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 3
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0
@@ -536,14 +529,12 @@ def test_decode_schedule_preempted():
Test decodes cannot be scheduled and preempted.
"""
scheduler = initialize_scheduler()
running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
running.append(seq_group)
scheduler._add_seq_group_to_running(seq_group)
scheduler.block_manager.can_append_slots = MagicMock()
def cannot_append_second_group(seq_group, num_lookahead_slots):
@@ -555,8 +546,8 @@ def test_decode_schedule_preempted():
# 1 cannot be scheduled, and the lowest priority (request 2)
# should be preempted. 1 will also be preempted.
budget = create_token_budget()
remainig_running, output = scheduler._schedule_running(
running, budget, curr_loras, policy)
output = scheduler._schedule_running(budget, curr_loras)
remainig_running = scheduler.running
assert len(remainig_running) == 0
assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0
@@ -577,14 +568,12 @@ def test_decode_swap_beam_search():
Test best_of > 1 swap out blocks
"""
scheduler = initialize_scheduler()
running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
budget = create_token_budget()
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
running.append(seq_group)
scheduler._add_seq_group_to_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs())
@@ -603,8 +592,8 @@ def test_decode_swap_beam_search():
expected_swap_mapping = [("5", "7")]
scheduler.block_manager.swap_out.return_value = expected_swap_mapping
remainig_running, output = scheduler._schedule_running(
running, budget, curr_loras, policy)
output = scheduler._schedule_running(budget, curr_loras)
remainig_running = scheduler.running
assert len(remainig_running) == 0
assert len(output.decode_seq_groups) == 2
assert len(output.prefill_seq_groups) == 0
@@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update():
"""
scheduler = initialize_scheduler()
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
running.append(seq_group)
scheduler._add_seq_group_to_running(seq_group)
# The last request should be swapped out.
scheduler.block_manager.append_slots = MagicMock()
scheduler.block_manager.append_slots.return_value = [(2, 3)]
budget = create_token_budget()
remaining_running, output = scheduler._schedule_running(
running, budget, curr_loras, policy)
output = scheduler._schedule_running(budget, curr_loras)
remaining_running = scheduler.running
assert len(remaining_running) == 0
assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0
@@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update():
def test_schedule_swapped_simple():
scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = []
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
scheduler._add_seq_group_to_swapped(seq_group)
budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 0
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 2
@@ -683,8 +668,6 @@ def test_schedule_swapped_simple():
def test_schedule_swapped_max_token_budget():
scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2):
@@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget():
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
scheduler._add_seq_group_to_swapped(seq_group)
budget = create_token_budget(token_budget=1)
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 2
@@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget():
# Verify num_batched_tokens are respected.
budget = create_token_budget(token_budget=1)
add_token_budget(budget, 1, 0)
remaining_swapped, output = scheduler._schedule_swapped(
remaining_swapped, budget, curr_loras, policy)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 0
@@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget():
def test_schedule_swapped_max_seqs():
scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = []
for i in range(4):
@@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs():
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
scheduler._add_seq_group_to_swapped(seq_group)
budget = create_token_budget(max_num_seqs=2)
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 2
assert budget.num_batched_tokens == 2
assert budget.num_curr_seqs == 2
@@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs():
assert len(output.prefill_seq_groups) == 0
# Verify num_curr_seqs are respected.
remaining_swapped, output = scheduler._schedule_swapped(
remaining_swapped, budget, curr_loras, policy)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 2
assert budget.num_batched_tokens == 2
assert budget.num_curr_seqs == 2
@@ -750,8 +731,6 @@ def test_schedule_swapped_max_loras():
def test_schedule_swapped_max_loras():
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config)
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras: Set[int] = set()
blocks_to_swap_out: List[Tuple[int, int]] = []
for i in range(2):
@@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras():
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
scheduler._add_seq_group_to_swapped(seq_group)
budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 1
@@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras():
def test_schedule_swapped_cannot_swap_in():
scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2):
@@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in():
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
scheduler._add_seq_group_to_swapped(seq_group)
# The last request should be swapped out.
scheduler.block_manager.can_swap_in = MagicMock()
scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
# Since we cannot swap in, none of the requests are swapped in.
budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 2
assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 0
@@ -806,8 +783,6 @@ def test_infeasible_swap():
def test_infeasible_swap():
scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2):
@@ -815,15 +790,15 @@ def test_infeasible_swap():
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
scheduler._add_seq_group_to_swapped(seq_group)
# The last request should be swapped out.
scheduler.block_manager.can_swap_in = MagicMock()
scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
# Since we cannot swap in, none of the requests are swapped in.
budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 0
assert len(output.infeasible_seq_groups) == 2
assert budget.num_batched_tokens == 0
@@ -834,23 +809,21 @@ def test_schedule_swapped_blocks_to_copy():
def test_schedule_swapped_blocks_to_copy():
scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
blocks_to_swap_out: List[Tuple[int, int]] = []
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
scheduler._add_seq_group_to_swapped(seq_group)
# The last request should be swapped out.
scheduler.block_manager.append_slots = MagicMock()
scheduler.block_manager.append_slots.return_value = [(2, 3)]
budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 0
assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0

View File

@@ -1,45 +0,0 @@
from collections import deque
from typing import Deque
from vllm.sequence import SequenceGroup
class Policy:
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
raise NotImplementedError
def sort_by_priority(
self,
now: float,
seq_groups: Deque[SequenceGroup],
) -> Deque[SequenceGroup]:
return deque(
sorted(
seq_groups,
key=lambda seq_group: self.get_priority(now, seq_group),
reverse=True,
))
class FCFS(Policy):
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
return now - seq_group.metrics.arrival_time
class PolicyFactory:
_POLICY_REGISTRY = {'fcfs': FCFS}
@classmethod
def get_policy(cls, policy_name: str, **kwargs) -> Policy:
return cls._POLICY_REGISTRY[policy_name](**kwargs)
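The deleted module registered exactly one policy. FCFS priority is `now - seq_group.metrics.arrival_time`, sorted descending, which is just arrival order, and `add_seq_group` already appends in arrival order, so the sort never reordered anything and the whole abstraction can go. A tiny self-check of that equivalence (a sketch with made-up timestamps):

```python
import time
from collections import deque

now = time.time()
# A queue appended to in arrival order: oldest request first.
arrivals = deque([now - 3.0, now - 2.0, now - 1.0])
# FCFS sorting by priority (now - arrival_time), highest first,
# reproduces the queue unchanged.
fcfs_sorted = sorted(arrivals, key=lambda t: now - t, reverse=True)
assert list(arrivals) == fcfs_sorted
```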

View File

@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -345,6 +344,16 @@ class Scheduler:
# Add sequence groups to the waiting queue.
self.waiting.append(seq_group)
def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the running queue.
# Only for testing purposes.
self.running.append(seq_group)
def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the swapped queue.
# Only for testing purposes.
self.swapped.append(seq_group)
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a sequence group with the given ID.
@@ -398,32 +407,26 @@ class Scheduler:
def _schedule_running(
self,
running_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
policy: Policy,
enable_chunking: bool = False,
) -> Tuple[deque, SchedulerRunningOutputs]:
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
Running queue should include decode and chunked prefill requests.
Args:
running_queue: The queue that contains running requests (i.e.,
decodes). The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any decodes are preempted.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any decodes are preempted.
policy: The sorting policy to sort running_queue.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
A tuple of remaining running queue (should be always 0) after
scheduling and SchedulerRunningOutputs.
SchedulerRunningOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: List[Tuple[int, int]] = []
@@ -436,10 +439,9 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
now = time.time()
running_queue = policy.sort_by_priority(now, running_queue)
running_queue = self.running
while running_queue:
seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens(
@@ -503,7 +505,7 @@ class Scheduler:
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)
return running_queue, SchedulerRunningOutputs(
return SchedulerRunningOutputs(
decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups,
preempted=preempted,
@@ -515,12 +517,10 @@ class Scheduler:
def _schedule_swapped(
self,
swapped_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
policy: Policy,
enable_chunking: bool = False,
) -> Tuple[deque, SchedulerSwappedInOutputs]:
) -> SchedulerSwappedInOutputs:
"""Schedule sequence groups that are swapped out.
It schedules swapped requests as long as it fits `budget` and
@@ -528,20 +528,16 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
swapped_queue: The queue that contains swapped out requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are swapped in.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are swapped in.
policy: The sorting policy to sort swapped_queue.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
A tuple of remaining swapped_queue after scheduling and
SchedulerSwappedInOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
@@ -549,10 +545,10 @@ class Scheduler:
blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = []
now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue)
infeasible_seq_groups: List[SequenceGroup] = []
swapped_queue = self.swapped
leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue:
seq_group = swapped_queue[0]
@@ -617,7 +613,7 @@ class Scheduler:
swapped_queue.extendleft(leftover_swapped)
return swapped_queue, SchedulerSwappedInOutputs(
return SchedulerSwappedInOutputs(
decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
@@ -644,11 +640,10 @@ class Scheduler:
def _schedule_prefills(
self,
waiting_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
) -> Tuple[deque, SchedulerPrefillOutputs]:
) -> SchedulerPrefillOutputs:
"""Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
@@ -660,8 +655,6 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
waiting_queue: The queue that contains prefill requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is
@@ -672,14 +665,12 @@ class Scheduler:
all tokens.
Returns:
A tuple of remaining waiting_queue after scheduling and
SchedulerSwappedInOutputs.
"""
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[SequenceGroup] = []
# We don't sort waiting queue because we assume it is sorted.
# Copy the queue so that the input queue is not modified.
waiting_queue = deque([s for s in waiting_queue])
waiting_queue = self.waiting
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
@@ -758,7 +749,7 @@ class Scheduler:
if len(seq_groups) > 0:
self.prev_prompt = True
return waiting_queue, SchedulerPrefillOutputs(
return SchedulerPrefillOutputs(
seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
@@ -785,53 +776,43 @@ class Scheduler:
seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None
remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
remaining_running, running_scheduled = (
self.running, SchedulerRunningOutputs.create_empty())
remaining_swapped, swapped_in = (
self.swapped, SchedulerSwappedInOutputs.create_empty())
prefills = SchedulerPrefillOutputs.create_empty()
running_scheduled = SchedulerRunningOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()
# If any requests are swapped, prioritized swapped requests.
if not self.swapped:
remaining_waiting, prefills = self._schedule_prefills(
self.waiting, budget, curr_loras, enable_chunking=False)
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=False)
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
# Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills.
if len(prefills.seq_groups) == 0:
remaining_running, running_scheduled = self._schedule_running(
self.running,
budget,
curr_loras,
fcfs_policy,
enable_chunking=False)
running_scheduled = self._schedule_running(budget,
curr_loras,
enable_chunking=False)
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0:
remaining_swapped, swapped_in = self._schedule_swapped(
self.swapped, budget, curr_loras, fcfs_policy)
swapped_in = self._schedule_swapped(budget, curr_loras)
assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
self.waiting = remaining_waiting
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
self.running = remaining_running
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out))
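For readers skimming the hunk above: the default path is now straight-line code in which swapped requests block new prefills, scheduled prefills block decodes, and any preemption blocks swap-ins, with `budget`, `curr_loras`, and the queues all updated in place. Condensed:

```python
# Condensed control flow of _schedule_default after this commit.
prefills = SchedulerPrefillOutputs.create_empty()
running_scheduled = SchedulerRunningOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()

if not self.swapped:                # pending swapped requests take priority
    prefills = self._schedule_prefills(budget, curr_loras,
                                       enable_chunking=False)
if len(prefills.seq_groups) == 0:   # never mix new prefills with decodes here
    running_scheduled = self._schedule_running(budget, curr_loras,
                                               enable_chunking=False)
    if (len(running_scheduled.preempted) +
            len(running_scheduled.swapped_out) == 0):  # room left: swap in
        swapped_in = self._schedule_swapped(budget, curr_loras)
```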
@@ -877,42 +858,32 @@ class Scheduler:
)
curr_loras: Set[int] = set()
remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
remaining_running, running_scheduled = (
self.running, SchedulerRunningOutputs.create_empty())
remaining_swapped, swapped_in = (
self.swapped, SchedulerSwappedInOutputs.create_empty())
prefills = SchedulerPrefillOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()
# Decoding should be always scheduled first by fcfs.
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
remaining_running, running_scheduled = self._schedule_running(
self.running,
budget,
curr_loras,
fcfs_policy,
enable_chunking=True)
running_scheduled = self._schedule_running(budget,
curr_loras,
enable_chunking=True)
# Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in.
if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0:
remaining_swapped, swapped_in = self._schedule_swapped(
self.swapped, budget, curr_loras, fcfs_policy)
swapped_in = self._schedule_swapped(budget, curr_loras)
# Schedule new prefills.
remaining_waiting, prefills = self._schedule_prefills(
self.waiting, budget, curr_loras, enable_chunking=True)
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=True)
assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
self.waiting = remaining_waiting
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
self.running = remaining_running
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
@@ -923,7 +894,6 @@ class Scheduler:
self.running.extend(
[s.seq_group for s in swapped_in.prefill_seq_groups])
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +