import enum
import time
from typing import Dict, List, Optional, Tuple

from cacheflow.core.block_manager import BlockSpaceManager
from cacheflow.core.policy import PolicyFactory
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
                                SequenceOutputs, SequenceStatus)

logger = init_logger(__name__)

_LOGGING_INTERVAL_SEC = 10


class PreemptionMode(enum.Enum):
    """Preemption modes.

    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
       and swap them back in when the sequences are resumed.
    2. Recomputation: Discard the blocks of the preempted sequences and
       recompute them when the sequences are resumed, treating the sequences
       as new prompts.
    """
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()


class Scheduler:
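    """Manages the execution of sequence groups over scheduler steps.

    The scheduler keeps sequence groups in three queues (waiting, running,
    and swapped), decides which groups run in each step, and records the
    block swap-in, swap-out, and copy operations that the workers must
    perform before executing the model.
    """
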
    def __init__(
        self,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        log_stats: bool,
    ) -> None:
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_sequences = max_num_sequences
        self.log_stats = log_stats

        # Instantiate the scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name='fcfs')
        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

        # Sequence groups in the WAITING state.
        self.waiting: List[SequenceGroup] = []
        # Sequence groups in the RUNNING state.
        self.running: List[SequenceGroup] = []
        # Sequence groups in the SWAPPED state.
        self.swapped: List[SequenceGroup] = []
        # Mapping: group_id -> num_steps.
        self.num_steps: Dict[int, int] = {}
        # Mapping: group_id -> sampling params.
        self.sampling_params: Dict[int, SamplingParams] = {}

        self.last_logging_time: float = 0.0
        # List of (timestamp, num_tokens) tuples.
        self.num_input_tokens: List[Tuple[float, int]] = []

    def add_sequence_groups(
        self,
        seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
    ) -> None:
        # Add sequence groups to the waiting queue.
        for seq_group, sampling_params in seq_groups:
            self.waiting.append(seq_group)
            self.sampling_params[seq_group.group_id] = sampling_params

    def _schedule(
        self,
    ) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
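        """Decide which sequence groups run in the next step.

        Returns a tuple of (blocks_to_swap_in, blocks_to_swap_out,
        blocks_to_copy, prompt_group_ids): the first three describe the
        block operations the workers must perform before model execution,
        and the last lists the groups scheduled in the prompt phase.
        """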
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.time()

        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
        # in order to minimize the preemption overheads.
        # Preemption happens only when there is no available slot to keep all
        # the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: List[SequenceGroup] = []
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.pop(0)
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence group.
                    victim_seq_group = self.running.pop(-1)
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence group can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        while self.swapped and not blocks_to_swap_out:
            seq_group = self.swapped[0]
            # If the sequence group has been preempted in this step, stop.
            if seq_group in preempted:
                break
            # If the sequence group cannot be swapped in, stop.
            if not self.block_manager.can_swap_in(seq_group):
                break

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
            if len(self.running) + num_seqs > self.max_num_sequences:
                break

            seq_group = self.swapped.pop(0)
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slot(seq_group, blocks_to_copy)
            self.running.append(seq_group)

        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        # Join waiting sequences if possible.
        prompt_group_ids: List[int] = []
        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
        # prioritized over the sequence groups in the WAITING state.
        # This is because we want to bound the amount of CPU memory taken by
        # the swapped sequence groups.
        if not self.swapped:
            # Optimization: We do not sort the waiting queue since the
            # preempted sequence groups are added to the front and the new
            # sequence groups are added to the back.
            while self.waiting:
                seq_group = self.waiting[0]
                # If the sequence group has been preempted in this step, stop.
                if seq_group in preempted:
                    break
                # If the sequence group cannot be allocated, stop.
                if not self.block_manager.can_allocate(seq_group):
                    break

                # If the number of batched tokens exceeds the limit, stop.
                num_prompt_tokens = seq_group.seqs[0].get_len()
                if (num_batched_tokens + num_prompt_tokens
                        > self.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should
                # not exceed the maximum number of sequences.
                num_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
                if len(self.running) + num_seqs > self.max_num_sequences:
                    break

                seq_group = self.waiting.pop(0)
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_batched_tokens += num_prompt_tokens
                prompt_group_ids.append(seq_group.group_id)

        if not self.log_stats:
            return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
                    prompt_group_ids)

        now = time.time()
        if num_batched_tokens > 0:
            self.num_input_tokens.append((now, num_batched_tokens))
        elapsed_time = now - self.last_logging_time
        if elapsed_time > _LOGGING_INTERVAL_SEC:
            self.last_logging_time = now
            self.num_input_tokens = [
                (t, n) for t, n in self.num_input_tokens
                if now - t < _LOGGING_INTERVAL_SEC
            ]
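            # Compute the average throughput over the logging window. The
            # most recent entry is excluded since those tokens have just
            # been scheduled and not yet processed.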
            if len(self.num_input_tokens) > 1:
                total_num_tokens = sum(
                    n for _, n in self.num_input_tokens[:-1])
                window = now - self.num_input_tokens[0][0]
                avg_throughput = total_num_tokens / window
            else:
                avg_throughput = 0.0

            num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
            num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks
            gpu_cache_usage = num_used_gpu_blocks / self.num_gpu_blocks
            if self.num_cpu_blocks > 0:
                num_free_cpu_blocks = (
                    self.block_manager.get_num_free_cpu_blocks())
                num_used_cpu_blocks = self.num_cpu_blocks - num_free_cpu_blocks
                cpu_cache_usage = num_used_cpu_blocks / self.num_cpu_blocks
            else:
                cpu_cache_usage = 0.0

            logger.info(
                f"Throughput: {avg_throughput:.1f} tokens/s, "
                f"Running: {len(self.running)} reqs, "
                f"Swapped: {len(self.swapped)} reqs, "
                f"Pending: {len(self.waiting)} reqs, "
                f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
                f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")

        return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
                prompt_group_ids)

    def step(self) -> List[SequenceGroup]:
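        """Perform one scheduling step and kick off model execution.

        Returns the list of sequence groups that run in this step.
        """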
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
         prompt_group_ids) = self._schedule()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        updated_seq_groups: List[SequenceGroup] = self.running.copy()

        for seq_group in self.running:
            group_id = seq_group.group_id
            is_prompt = group_id in prompt_group_ids

            input_tokens: Dict[int, List[int]] = {}
            seq_logprobs: Dict[int, float] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                if is_prompt:
                    input_tokens[seq_id] = seq.get_token_ids()
                else:
                    input_tokens[seq_id] = [seq.get_last_token_id()]
                seq_logprobs[seq_id] = seq.cumulative_logprobs
                # NOTE(woosuk): Sequences in the same group have the same
                # sequence length.
                seq_len = seq.get_len()

            seq_group_metadata = SequenceGroupMetadata(
                group_id=group_id,
                is_prompt=is_prompt,
                input_tokens=input_tokens,
                context_len=seq_len,
                seq_logprobs=seq_logprobs,
                sampling_params=self.sampling_params[group_id],
                block_tables=block_tables,
            )
            seq_group_metadata_list.append(seq_group_metadata)

        # Execute the first stage of the pipeline.
        if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
            # Swap in and swap out should never happen at the same time.
            assert not (blocks_to_swap_in and blocks_to_swap_out)
            self.controllers[0].execute_stage(
                seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

        return updated_seq_groups

    def post_step(
        self,
        seq_outputs: Dict[int, SequenceOutputs],
    ) -> None:
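        """Update the scheduler state with the model outputs of a step.

        This appends the sampled tokens to the running sequences, handles
        beam search forks, and frees sequences that have finished.
        """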
        # Update the running sequences and free blocks.
        for seq_group in self.running:
            group_id = seq_group.group_id
            self.num_steps[group_id] += 1
            stop_token_ids = self.sampling_params[group_id].stop_token_ids

            # Process beam search results before processing the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                output = seq_outputs[seq.seq_id]
                if seq.seq_id != output.parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam
                    # search). Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    parent_seq = seq_group.find(output.parent_seq_id)
                    parent_seq.fork(seq)
                    self.block_manager.fork(parent_seq, seq)

            # Process the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                # Append a new token to the sequence.
                output = seq_outputs[seq.seq_id]
                seq.append_token(output.output_token, output.logprobs)

                # Check if the sequence has generated a stop token.
                if output.output_token in stop_token_ids:
                    self._free_seq(seq)
                    continue

                # Check if the sequence has reached the maximum number of
                # steps.
                max_num_steps = self.sampling_params[group_id].max_num_steps
                if self.num_steps[group_id] == max_num_steps:
                    self._free_seq(seq)
                    continue

        # Update the running sequences.
        running: List[SequenceGroup] = []
        for seq_group in self.running:
            if seq_group.is_finished():
                self._free_seq_group(seq_group)
            else:
                running.append(seq_group)
        self.running = running

    def _allocate(self, seq_group: SequenceGroup) -> None:
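        """Allocate KV cache blocks for a new sequence group and mark its
        sequences as RUNNING."""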
        self.block_manager.allocate(seq_group)
        for seq in seq_group.seqs:
            seq.status = SequenceStatus.RUNNING
        if seq_group.group_id not in self.num_steps:
            self.num_steps[seq_group.group_id] = 0

    def _append_slot(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
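        """Reserve one new token slot per running sequence.

        If the block manager needs to copy a block to append the slot
        (e.g., when a forked sequence writes to a shared block), the source
        and destination blocks are recorded in blocks_to_copy.
        """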
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append_slot(seq)
            if ret is not None:
                src_block, dst_block = ret
                if src_block in blocks_to_copy:
                    blocks_to_copy[src_block].append(dst_block)
                else:
                    blocks_to_copy[src_block] = [dst_block]

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
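        """Preempt a running sequence group by swapping or recomputation."""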
        # If preemption mode is not specified, we determine the mode as
        # follows: We use recomputation by default since it incurs lower
        # overhead than swapping. However, when the sequence group has
        # multiple sequences (e.g., beam search), recomputation is not
        # supported. In such a case, we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with
        # multiple sequences. This may require a more sophisticated CUDA
        # kernel.
        if preemption_mode is None:
            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
            if len(seqs) == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            assert False, 'Invalid preemption mode.'

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.block_manager.free(seq)
        # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
        self.waiting.insert(0, seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
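        """Preempt a sequence group by swapping its blocks out to CPU memory
        and moving it to the swapped queue."""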
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        for seq in seqs:
            seq.status = SequenceStatus.SWAPPED
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _free_seq(self, seq: Sequence) -> None:
        seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)

    def _free_seq_group(self, seq_group: SequenceGroup) -> None:
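        """Drop the per-group bookkeeping for a finished sequence group."""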
        group_id = seq_group.group_id
        del self.num_steps[group_id]
        del self.sampling_params[group_id]

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
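        """Record the CPU-to-GPU block mapping for a swapped-out group and
        mark its sequences as RUNNING."""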
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
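        """Record the GPU-to-CPU block mapping for a running group and mark
        its sequences as SWAPPED."""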
        assert self.block_manager.can_swap_out(seq_group)
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED