[core][scheduler] simplify and improve scheduler (#6867)
This commit is contained in:
parent
3c10591ef2
commit
c8a7e93273
@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
|
||||
|
||||
# Allow only 2 sequences of ~128 tokens in worst case.
|
||||
# Note 16 = 128/block_size
|
||||
"num_gpu_blocks_override": 2 * (16 + 1),
|
||||
"num_gpu_blocks_override": 2 * (16 + 2),
|
||||
}
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
|
@ -1,13 +1,12 @@
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Deque, List, Set, Tuple
|
||||
from typing import List, Set, Tuple
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # noqa
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
from vllm.core.policy import PolicyFactory
|
||||
from vllm.core.scheduler import Scheduler, SchedulingBudget
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import Logprob, SequenceGroup, SequenceStatus
|
||||
@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len():
|
||||
"""
|
||||
scheduler = initialize_scheduler(max_model_len=30)
|
||||
_, seq_group = create_dummy_prompt("0", prompt_length=60)
|
||||
waiting = deque([seq_group])
|
||||
scheduler.add_seq_group(seq_group)
|
||||
budget = create_token_budget()
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 1
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget():
|
||||
Test token budget respected.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget(token_budget=0)
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
|
||||
# 0 token budget == nothing is scheduled.
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget():
|
||||
|
||||
# 60 token budget == 1 request scheduled.
|
||||
budget = create_token_budget(token_budget=60)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 1
|
||||
assert budget.num_batched_tokens == 60
|
||||
@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget():
|
||||
|
||||
# Test when current_batched_tokens respected.
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
budget = create_token_budget(token_budget=60)
|
||||
add_token_budget(budget, 30, 0)
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
# Cannot schedule a prompt that doesn't fit the budget.
|
||||
waiting.append(seq_group)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 30
|
||||
@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget():
|
||||
assert len(remaining_waiting) == 1
|
||||
budget = create_token_budget(token_budget=90)
|
||||
add_token_budget(budget, 30, 0)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.seq_groups) == 1
|
||||
assert budget.num_batched_tokens == 90
|
||||
assert budget.num_curr_seqs == 1
|
||||
@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs():
|
||||
Test max seq respected.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget(max_num_seqs=2)
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 2
|
||||
assert budget.num_batched_tokens == 120
|
||||
@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs():
|
||||
assert len(remaining_waiting) == 1
|
||||
|
||||
# Verify curr_num_seqs respected.
|
||||
waiting = deque()
|
||||
scheduler.waiting = deque()
|
||||
budget = create_token_budget(max_num_seqs=2)
|
||||
add_token_budget(budget, 0, 2)
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora():
|
||||
"""
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
|
||||
scheduler = initialize_scheduler(lora_config=lora_config)
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget(token_budget=120)
|
||||
curr_loras: Set[int] = set()
|
||||
for i in range(2):
|
||||
@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora():
|
||||
lora_name=str(i),
|
||||
lora_int_id=i + 1,
|
||||
lora_path="abc"))
|
||||
waiting.append(seq_group)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
# Add two more requests to verify lora is prioritized.
|
||||
# 0: Lora, 1: Lora, 2: regular, 3: regular
|
||||
# In the first iteration, index 0, 2 is scheduled.
|
||||
@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora():
|
||||
# prioritized. Verify that.
|
||||
for i in range(2, 4):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
# Schedule 2 requests (0 and 2)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, curr_loras)
|
||||
output = scheduler._schedule_prefills(budget, curr_loras)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 2
|
||||
assert budget.num_batched_tokens == 120
|
||||
@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora():
|
||||
# Reset curr_loras so that it can be scheduled.
|
||||
curr_loras = set()
|
||||
budget = create_token_budget(token_budget=60)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
remaining_waiting, budget, curr_loras)
|
||||
output = scheduler._schedule_prefills(budget, curr_loras)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.seq_groups) == 1
|
||||
assert output.seq_groups[0].seq_group.request_id == "1"
|
||||
assert len(remaining_waiting) == 1
|
||||
@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity():
|
||||
Test sequence cannot be scheduled due to block manager has no capacity.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting: Deque[SequenceGroup] = deque()
|
||||
budget = create_token_budget()
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
scheduler.block_manager.can_allocate = MagicMock()
|
||||
scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER
|
||||
remainig_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
assert budget.num_curr_seqs == 0
|
||||
assert len(remainig_waiting) == 3
|
||||
assert len(remaining_waiting) == 3
|
||||
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
budget = create_token_budget()
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
scheduler.block_manager.can_allocate = MagicMock()
|
||||
scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
output = scheduler._schedule_prefills(budget, None)
|
||||
remaining_waiting = scheduler.waiting
|
||||
assert len(output.ignored_seq_groups) == 3
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
@ -536,14 +529,12 @@ def test_decode_schedule_preempted():
|
||||
Test decodes cannot be scheduled and preempted.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
running: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
running.append(seq_group)
|
||||
scheduler._add_seq_group_to_running(seq_group)
|
||||
scheduler.block_manager.can_append_slots = MagicMock()
|
||||
|
||||
def cannot_append_second_group(seq_group, num_lookahead_slots):
|
||||
@ -555,8 +546,8 @@ def test_decode_schedule_preempted():
|
||||
# 1 cannot be scheduled, and the lowest priority (request 2)
|
||||
# should be preempted. 1 will also be preempted.
|
||||
budget = create_token_budget()
|
||||
remainig_running, output = scheduler._schedule_running(
|
||||
running, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_running(budget, curr_loras)
|
||||
remainig_running = scheduler.running
|
||||
assert len(remainig_running) == 0
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
@ -577,14 +568,12 @@ def test_decode_swap_beam_search():
|
||||
Test best_of > 1 swap out blocks
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
running: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
budget = create_token_budget()
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
running.append(seq_group)
|
||||
scheduler._add_seq_group_to_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
budget.add_num_seqs(seq_group.request_id,
|
||||
seq_group.get_max_num_running_seqs())
|
||||
@ -603,8 +592,8 @@ def test_decode_swap_beam_search():
|
||||
expected_swap_mapping = [("5", "7")]
|
||||
scheduler.block_manager.swap_out.return_value = expected_swap_mapping
|
||||
|
||||
remainig_running, output = scheduler._schedule_running(
|
||||
running, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_running(budget, curr_loras)
|
||||
remainig_running = scheduler.running
|
||||
assert len(remainig_running) == 0
|
||||
assert len(output.decode_seq_groups) == 2
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update():
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
running: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
running.append(seq_group)
|
||||
scheduler._add_seq_group_to_running(seq_group)
|
||||
|
||||
# The last request should be swapped out.
|
||||
scheduler.block_manager.append_slots = MagicMock()
|
||||
scheduler.block_manager.append_slots.return_value = [(2, 3)]
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_running, output = scheduler._schedule_running(
|
||||
running, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_running(budget, curr_loras)
|
||||
remaining_running = scheduler.running
|
||||
assert len(remaining_running) == 0
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update():
|
||||
|
||||
def test_schedule_swapped_simple():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 0
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 2
|
||||
@ -683,8 +668,6 @@ def test_schedule_swapped_simple():
|
||||
|
||||
def test_schedule_swapped_max_token_budget():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for _ in range(2):
|
||||
@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget():
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
budget = create_token_budget(token_budget=1)
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 1
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 2
|
||||
@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget():
|
||||
# Verify num_batched_tokens are respected.
|
||||
budget = create_token_budget(token_budget=1)
|
||||
add_token_budget(budget, 1, 0)
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
remaining_swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 1
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 0
|
||||
@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget():
|
||||
|
||||
def test_schedule_swapped_max_seqs():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for i in range(4):
|
||||
@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs():
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
budget = create_token_budget(max_num_seqs=2)
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 2
|
||||
assert budget.num_batched_tokens == 2
|
||||
assert budget.num_curr_seqs == 2
|
||||
@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs():
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
# Verify num_curr_seqs are respected.
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
remaining_swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 2
|
||||
assert budget.num_batched_tokens == 2
|
||||
assert budget.num_curr_seqs == 2
|
||||
@ -750,8 +731,6 @@ def test_schedule_swapped_max_seqs():
|
||||
def test_schedule_swapped_max_loras():
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
|
||||
scheduler = initialize_scheduler(lora_config=lora_config)
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras: Set[int] = set()
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for i in range(2):
|
||||
@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras():
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 1
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 1
|
||||
@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras():
|
||||
|
||||
def test_schedule_swapped_cannot_swap_in():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for _ in range(2):
|
||||
@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in():
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
# The last request should be swapped out.
|
||||
scheduler.block_manager.can_swap_in = MagicMock()
|
||||
scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
|
||||
# Since we cannot swap in, none of the requests are swapped in.
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 2
|
||||
assert budget.num_batched_tokens == 0
|
||||
assert budget.num_curr_seqs == 0
|
||||
@ -806,8 +783,6 @@ def test_schedule_swapped_cannot_swap_in():
|
||||
|
||||
def test_infeasible_swap():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
for _ in range(2):
|
||||
@ -815,15 +790,15 @@ def test_infeasible_swap():
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
# The last request should be swapped out.
|
||||
scheduler.block_manager.can_swap_in = MagicMock()
|
||||
scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
|
||||
# Since we cannot swap in, none of the requests are swapped in.
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 0
|
||||
assert len(output.infeasible_seq_groups) == 2
|
||||
assert budget.num_batched_tokens == 0
|
||||
@ -834,23 +809,21 @@ def test_infeasible_swap():
|
||||
|
||||
def test_schedule_swapped_blocks_to_copy():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped: Deque[SequenceGroup] = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
scheduler._add_seq_group_to_swapped(seq_group)
|
||||
|
||||
# The last request should be swapped out.
|
||||
scheduler.block_manager.append_slots = MagicMock()
|
||||
scheduler.block_manager.append_slots.return_value = [(2, 3)]
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
output = scheduler._schedule_swapped(budget, curr_loras)
|
||||
remaining_swapped = scheduler.swapped
|
||||
assert len(remaining_swapped) == 0
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
@ -1,45 +0,0 @@
|
||||
from collections import deque
|
||||
from typing import Deque
|
||||
|
||||
from vllm.sequence import SequenceGroup
|
||||
|
||||
|
||||
class Policy:
|
||||
|
||||
def get_priority(
|
||||
self,
|
||||
now: float,
|
||||
seq_group: SequenceGroup,
|
||||
) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
def sort_by_priority(
|
||||
self,
|
||||
now: float,
|
||||
seq_groups: Deque[SequenceGroup],
|
||||
) -> Deque[SequenceGroup]:
|
||||
return deque(
|
||||
sorted(
|
||||
seq_groups,
|
||||
key=lambda seq_group: self.get_priority(now, seq_group),
|
||||
reverse=True,
|
||||
))
|
||||
|
||||
|
||||
class FCFS(Policy):
|
||||
|
||||
def get_priority(
|
||||
self,
|
||||
now: float,
|
||||
seq_group: SequenceGroup,
|
||||
) -> float:
|
||||
return now - seq_group.metrics.arrival_time
|
||||
|
||||
|
||||
class PolicyFactory:
|
||||
|
||||
_POLICY_REGISTRY = {'fcfs': FCFS}
|
||||
|
||||
@classmethod
|
||||
def get_policy(cls, policy_name: str, **kwargs) -> Policy:
|
||||
return cls._POLICY_REGISTRY[policy_name](**kwargs)
|
@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.core.policy import Policy, PolicyFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
@ -345,6 +344,16 @@ class Scheduler:
|
||||
# Add sequence groups to the waiting queue.
|
||||
self.waiting.append(seq_group)
|
||||
|
||||
def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
|
||||
# Add sequence groups to the running queue.
|
||||
# Only for testing purposes.
|
||||
self.running.append(seq_group)
|
||||
|
||||
def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
|
||||
# Add sequence groups to the swapped queue.
|
||||
# Only for testing purposes.
|
||||
self.swapped.append(seq_group)
|
||||
|
||||
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
"""Aborts a sequence group with the given ID.
|
||||
|
||||
@ -398,32 +407,26 @@ class Scheduler:
|
||||
|
||||
def _schedule_running(
|
||||
self,
|
||||
running_queue: deque,
|
||||
budget: SchedulingBudget,
|
||||
curr_loras: Optional[Set[int]],
|
||||
policy: Policy,
|
||||
enable_chunking: bool = False,
|
||||
) -> Tuple[deque, SchedulerRunningOutputs]:
|
||||
) -> SchedulerRunningOutputs:
|
||||
"""Schedule sequence groups that are running.
|
||||
|
||||
Running queue should include decode and chunked prefill requests.
|
||||
|
||||
Args:
|
||||
running_queue: The queue that contains running requests (i.e.,
|
||||
decodes). The given arguments are NOT in-place modified.
|
||||
budget: The scheduling budget. The argument is in-place updated
|
||||
when any decodes are preempted.
|
||||
curr_loras: Currently batched lora request ids. The argument is
|
||||
in-place updated when any decodes are preempted.
|
||||
policy: The sorting policy to sort running_queue.
|
||||
enable_chunking: If True, seq group can be chunked and only a
|
||||
chunked number of tokens are scheduled if
|
||||
`budget.num_batched_tokens` has not enough capacity to schedule
|
||||
all tokens.
|
||||
|
||||
Returns:
|
||||
A tuple of remaining running queue (should be always 0) after
|
||||
scheduling and SchedulerRunningOutputs.
|
||||
SchedulerRunningOutputs.
|
||||
"""
|
||||
# Blocks that need to be swapped or copied before model execution.
|
||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
||||
@ -436,10 +439,9 @@ class Scheduler:
|
||||
|
||||
# NOTE(woosuk): Preemption happens only when there is no available slot
|
||||
# to keep all the sequence groups in the RUNNING state.
|
||||
# In this case, the policy is responsible for deciding which sequence
|
||||
# groups to preempt.
|
||||
now = time.time()
|
||||
running_queue = policy.sort_by_priority(now, running_queue)
|
||||
|
||||
running_queue = self.running
|
||||
|
||||
while running_queue:
|
||||
seq_group = running_queue[0]
|
||||
num_running_tokens = self._get_num_new_tokens(
|
||||
@ -503,7 +505,7 @@ class Scheduler:
|
||||
if curr_loras is not None and seq_group.lora_int_id > 0:
|
||||
curr_loras.add(seq_group.lora_int_id)
|
||||
|
||||
return running_queue, SchedulerRunningOutputs(
|
||||
return SchedulerRunningOutputs(
|
||||
decode_seq_groups=decode_seq_groups,
|
||||
prefill_seq_groups=prefill_seq_groups,
|
||||
preempted=preempted,
|
||||
@ -515,12 +517,10 @@ class Scheduler:
|
||||
|
||||
def _schedule_swapped(
|
||||
self,
|
||||
swapped_queue: deque,
|
||||
budget: SchedulingBudget,
|
||||
curr_loras: Optional[Set[int]],
|
||||
policy: Policy,
|
||||
enable_chunking: bool = False,
|
||||
) -> Tuple[deque, SchedulerSwappedInOutputs]:
|
||||
) -> SchedulerSwappedInOutputs:
|
||||
"""Schedule sequence groups that are swapped out.
|
||||
|
||||
It schedules swapped requests as long as it fits `budget` and
|
||||
@ -528,20 +528,16 @@ class Scheduler:
|
||||
`budget` and `curr_loras` are updated based on scheduled seq_groups.
|
||||
|
||||
Args:
|
||||
swapped_queue: The queue that contains swapped out requests.
|
||||
The given arguments are NOT in-place modified.
|
||||
budget: The scheduling budget. The argument is in-place updated
|
||||
when any requests are swapped in.
|
||||
curr_loras: Currently batched lora request ids. The argument is
|
||||
in-place updated when any requests are swapped in.
|
||||
policy: The sorting policy to sort swapped_queue.
|
||||
enable_chunking: If True, seq group can be chunked and only a
|
||||
chunked number of tokens are scheduled if
|
||||
`budget.num_batched_tokens` has not enough capacity to schedule
|
||||
all tokens.
|
||||
|
||||
Returns:
|
||||
A tuple of remaining swapped_queue after scheduling and
|
||||
SchedulerSwappedInOutputs.
|
||||
"""
|
||||
# Blocks that need to be swapped or copied before model execution.
|
||||
@ -549,10 +545,10 @@ class Scheduler:
|
||||
blocks_to_copy: List[Tuple[int, int]] = []
|
||||
decode_seq_groups: List[ScheduledSequenceGroup] = []
|
||||
prefill_seq_groups: List[ScheduledSequenceGroup] = []
|
||||
now = time.time()
|
||||
swapped_queue = policy.sort_by_priority(now, swapped_queue)
|
||||
infeasible_seq_groups: List[SequenceGroup] = []
|
||||
|
||||
swapped_queue = self.swapped
|
||||
|
||||
leftover_swapped: Deque[SequenceGroup] = deque()
|
||||
while swapped_queue:
|
||||
seq_group = swapped_queue[0]
|
||||
@ -617,7 +613,7 @@ class Scheduler:
|
||||
|
||||
swapped_queue.extendleft(leftover_swapped)
|
||||
|
||||
return swapped_queue, SchedulerSwappedInOutputs(
|
||||
return SchedulerSwappedInOutputs(
|
||||
decode_seq_groups=decode_seq_groups,
|
||||
prefill_seq_groups=prefill_seq_groups,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
@ -644,11 +640,10 @@ class Scheduler:
|
||||
|
||||
def _schedule_prefills(
|
||||
self,
|
||||
waiting_queue: deque,
|
||||
budget: SchedulingBudget,
|
||||
curr_loras: Optional[Set[int]],
|
||||
enable_chunking: bool = False,
|
||||
) -> Tuple[deque, SchedulerPrefillOutputs]:
|
||||
) -> SchedulerPrefillOutputs:
|
||||
"""Schedule sequence groups that are in prefill stage.
|
||||
|
||||
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
|
||||
@ -660,8 +655,6 @@ class Scheduler:
|
||||
`budget` and `curr_loras` are updated based on scheduled seq_groups.
|
||||
|
||||
Args:
|
||||
waiting_queue: The queue that contains prefill requests.
|
||||
The given arguments are NOT in-place modified.
|
||||
budget: The scheduling budget. The argument is in-place updated
|
||||
when any requests are scheduled.
|
||||
curr_loras: Currently batched lora request ids. The argument is
|
||||
@ -672,14 +665,12 @@ class Scheduler:
|
||||
all tokens.
|
||||
|
||||
Returns:
|
||||
A tuple of remaining waiting_queue after scheduling and
|
||||
SchedulerSwappedInOutputs.
|
||||
"""
|
||||
ignored_seq_groups: List[SequenceGroup] = []
|
||||
seq_groups: List[SequenceGroup] = []
|
||||
# We don't sort waiting queue because we assume it is sorted.
|
||||
# Copy the queue so that the input queue is not modified.
|
||||
waiting_queue = deque([s for s in waiting_queue])
|
||||
|
||||
waiting_queue = self.waiting
|
||||
|
||||
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
|
||||
while self._passed_delay(time.time()) and waiting_queue:
|
||||
@ -758,7 +749,7 @@ class Scheduler:
|
||||
if len(seq_groups) > 0:
|
||||
self.prev_prompt = True
|
||||
|
||||
return waiting_queue, SchedulerPrefillOutputs(
|
||||
return SchedulerPrefillOutputs(
|
||||
seq_groups=seq_groups,
|
||||
ignored_seq_groups=ignored_seq_groups,
|
||||
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
|
||||
@ -785,53 +776,43 @@ class Scheduler:
|
||||
seq_group.lora_int_id for seq_group in self.running
|
||||
if seq_group.lora_int_id > 0) if self.lora_enabled else None
|
||||
|
||||
remaining_waiting, prefills = (self.waiting,
|
||||
SchedulerPrefillOutputs.create_empty())
|
||||
remaining_running, running_scheduled = (
|
||||
self.running, SchedulerRunningOutputs.create_empty())
|
||||
remaining_swapped, swapped_in = (
|
||||
self.swapped, SchedulerSwappedInOutputs.create_empty())
|
||||
prefills = SchedulerPrefillOutputs.create_empty()
|
||||
running_scheduled = SchedulerRunningOutputs.create_empty()
|
||||
swapped_in = SchedulerSwappedInOutputs.create_empty()
|
||||
|
||||
# If any requests are swapped, prioritized swapped requests.
|
||||
if not self.swapped:
|
||||
remaining_waiting, prefills = self._schedule_prefills(
|
||||
self.waiting, budget, curr_loras, enable_chunking=False)
|
||||
prefills = self._schedule_prefills(budget,
|
||||
curr_loras,
|
||||
enable_chunking=False)
|
||||
|
||||
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
# Don't schedule decodes if prefills are scheduled.
|
||||
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
|
||||
# only contains decode requests, not chunked prefills.
|
||||
if len(prefills.seq_groups) == 0:
|
||||
remaining_running, running_scheduled = self._schedule_running(
|
||||
self.running,
|
||||
budget,
|
||||
curr_loras,
|
||||
fcfs_policy,
|
||||
enable_chunking=False)
|
||||
running_scheduled = self._schedule_running(budget,
|
||||
curr_loras,
|
||||
enable_chunking=False)
|
||||
|
||||
# If any sequence group is preempted, do not swap in any sequence
|
||||
# group. because it means there's no slot for new running requests.
|
||||
if len(running_scheduled.preempted) + len(
|
||||
running_scheduled.swapped_out) == 0:
|
||||
remaining_swapped, swapped_in = self._schedule_swapped(
|
||||
self.swapped, budget, curr_loras, fcfs_policy)
|
||||
swapped_in = self._schedule_swapped(budget, curr_loras)
|
||||
|
||||
assert (budget.num_batched_tokens <=
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
|
||||
|
||||
# Update waiting requests.
|
||||
self.waiting = remaining_waiting
|
||||
self.waiting.extendleft(running_scheduled.preempted)
|
||||
# Update new running requests.
|
||||
self.running = remaining_running
|
||||
self.running.extend([s.seq_group for s in prefills.seq_groups])
|
||||
self.running.extend(
|
||||
[s.seq_group for s in running_scheduled.decode_seq_groups])
|
||||
self.running.extend(
|
||||
[s.seq_group for s in swapped_in.decode_seq_groups])
|
||||
# Update swapped requests.
|
||||
self.swapped = remaining_swapped
|
||||
self.swapped.extend(running_scheduled.swapped_out)
|
||||
preempted = (len(running_scheduled.preempted) +
|
||||
len(running_scheduled.swapped_out))
|
||||
@ -877,42 +858,32 @@ class Scheduler:
|
||||
)
|
||||
curr_loras: Set[int] = set()
|
||||
|
||||
remaining_waiting, prefills = (self.waiting,
|
||||
SchedulerPrefillOutputs.create_empty())
|
||||
remaining_running, running_scheduled = (
|
||||
self.running, SchedulerRunningOutputs.create_empty())
|
||||
remaining_swapped, swapped_in = (
|
||||
self.swapped, SchedulerSwappedInOutputs.create_empty())
|
||||
prefills = SchedulerPrefillOutputs.create_empty()
|
||||
swapped_in = SchedulerSwappedInOutputs.create_empty()
|
||||
|
||||
# Decoding should be always scheduled first by fcfs.
|
||||
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
remaining_running, running_scheduled = self._schedule_running(
|
||||
self.running,
|
||||
budget,
|
||||
curr_loras,
|
||||
fcfs_policy,
|
||||
enable_chunking=True)
|
||||
running_scheduled = self._schedule_running(budget,
|
||||
curr_loras,
|
||||
enable_chunking=True)
|
||||
|
||||
# Schedule swapped out requests.
|
||||
# If preemption happens, it means we don't have space for swap-in.
|
||||
if len(running_scheduled.preempted) + len(
|
||||
running_scheduled.swapped_out) == 0:
|
||||
remaining_swapped, swapped_in = self._schedule_swapped(
|
||||
self.swapped, budget, curr_loras, fcfs_policy)
|
||||
swapped_in = self._schedule_swapped(budget, curr_loras)
|
||||
|
||||
# Schedule new prefills.
|
||||
remaining_waiting, prefills = self._schedule_prefills(
|
||||
self.waiting, budget, curr_loras, enable_chunking=True)
|
||||
prefills = self._schedule_prefills(budget,
|
||||
curr_loras,
|
||||
enable_chunking=True)
|
||||
|
||||
assert (budget.num_batched_tokens <=
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
|
||||
|
||||
# Update waiting requests.
|
||||
self.waiting = remaining_waiting
|
||||
self.waiting.extendleft(running_scheduled.preempted)
|
||||
# Update new running requests.
|
||||
self.running = remaining_running
|
||||
self.running.extend([s.seq_group for s in prefills.seq_groups])
|
||||
self.running.extend(
|
||||
[s.seq_group for s in running_scheduled.decode_seq_groups])
|
||||
@ -923,7 +894,6 @@ class Scheduler:
|
||||
self.running.extend(
|
||||
[s.seq_group for s in swapped_in.prefill_seq_groups])
|
||||
# Update swapped requests.
|
||||
self.swapped = remaining_swapped
|
||||
self.swapped.extend(running_scheduled.swapped_out)
|
||||
return SchedulerOutputs(
|
||||
scheduled_seq_groups=(prefills.seq_groups +
|
||||
|
Loading…
x
Reference in New Issue
Block a user