vllm/cacheflow/master/scheduler.py
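"""Scheduler for CacheFlow.

At every step, decides which sequence groups run, which are swapped
between GPU and CPU memory, and which pending groups join the batch,
then dispatches the scheduled work to the model pipeline.
"""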

from typing import Dict, List

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus

_MAX_NUM_BATCHED_TOKENS = 2048

class Scheduler:
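    """Schedules sequence groups over GPU and CPU block space.

    Three queues are maintained: `running` (groups currently being
    executed), `swapped` (groups preempted to CPU blocks, LIFO), and
    `pending` (groups waiting to be admitted, FIFO). The
    BlockSpaceManager decides when blocks can be allocated, appended,
    or swapped.
    """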
    def __init__(
        self,
        frontend: Frontend,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.frontend = frontend
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks

        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

        # Running sequence groups (FIFO).
        self.running: List[SequenceGroup] = []
        # Mapping: group_id -> num_steps.
        self.num_steps: Dict[int, int] = {}
        # Mapping: group_id -> sampling params.
        self.sampling_params: Dict[int, SamplingParams] = {}
        # Swapped sequence groups (LIFO).
        self.swapped: List[SequenceGroup] = []
        # Pending sequence groups (FIFO).
        self.pending: List[SequenceGroup] = []

def _fetch_inputs(self) -> None:
inputs = self.frontend.get_inputs()
for seq_group, sampling_params in inputs:
self.pending.append(seq_group)
self.sampling_params[seq_group.group_id] = sampling_params
def _free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)

def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.seqs:
seq.status = SequenceStatus.RUNNING
self.running.append(seq_group)
# FIXME(woosuk): Support interactive generation.
self.num_steps[seq_group.group_id] = 0
def _append(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, List[int]],
) -> None:
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
ret = self.block_manager.append(seq)
if ret is not None:
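                # A non-None return value means the last block is shared
                # with another sequence, so the new token's slot must be
                # copied from `src_block` to a private `dst_block`
                # (copy-on-write).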
src_block, dst_block = ret
if src_block in blocks_to_copy:
blocks_to_copy[src_block].append(dst_block)
else:
blocks_to_copy[src_block] = [dst_block]
def _swap_in(
self,
seq_group: SequenceGroup,
blocks_to_swap_in: Dict[int, int],
) -> None:
mapping = self.block_manager.swap_in(seq_group)
blocks_to_swap_in.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
seq.status = SequenceStatus.RUNNING
self.running.append(seq_group)
def _swap_out(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
) -> None:
assert self.block_manager.can_swap_out(seq_group)
mapping = self.block_manager.swap_out(seq_group)
blocks_to_swap_out.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
self.swapped.append(seq_group)
def step(self) -> None:
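        """Schedules the next iteration of model execution.

        Reserves slots for running groups (swapping victims out on OOM),
        swaps groups back in when space allows, admits pending groups up
        to _MAX_NUM_BATCHED_TOKENS, and dispatches the batch to the
        first pipeline stage.
        """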
        # Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
# 1. Reserve new slots for the running sequences.
# NOTE: Here we implicitly assume FCFS scheduling.
# That is, the most recently added sequence group is the first
# to be swapped out.
victim_idx = len(self.running) - 1
for i, seq_group in enumerate(self.running):
if i > victim_idx:
# The i-th sequence group has already been swapped out.
break
# OOM. Swap out the victim sequence groups.
while not self.block_manager.can_append(seq_group):
victim_seq_group = self.running[victim_idx]
self._swap_out(victim_seq_group, blocks_to_swap_out)
victim_idx -= 1
if i > victim_idx:
# No other sequence groups can be swapped out.
break
else:
self._append(seq_group, blocks_to_copy)
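
        # Keep only the groups that were not swapped out in this step.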
self.running = self.running[:victim_idx + 1]
# 2. Swap in the swapped sequences if possible.
# NOTE: Here we implicitly assume FCFS scheduling.
# The swapped sequences are in LIFO order.
for i, seq_group in enumerate(reversed(self.swapped)):
if self.block_manager.can_swap_in(seq_group):
self._swap_in(seq_group, blocks_to_swap_in)
self._append(seq_group, blocks_to_copy)
else:
# OOM. Stop swapping.
self.swapped = self.swapped[:len(self.swapped) - i]
break
else:
# All swapped sequences are swapped in.
            self.swapped.clear()

# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
assert not blocks_to_swap_out
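
        # Each running sequence contributes one token to the next batch.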
num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running
)
# 3. Join new sequences if possible.
# NOTE: Here we implicitly assume FCFS scheduling.
# TODO(woosuk): Add a batching policy to control the batch size.
if not self.swapped:
self._fetch_inputs()
for i, seq_group in enumerate(self.pending):
num_prompt_tokens = seq_group.seqs[0].get_len()
if self.block_manager.can_allocate(seq_group):
if (num_batched_tokens + num_prompt_tokens
<= _MAX_NUM_BATCHED_TOKENS):
self._allocate(seq_group)
num_batched_tokens += num_prompt_tokens
continue
self.pending = self.pending[i:]
break
else:
self.pending.clear()
# 4. Create input data structures.
input_seq_groups: List[SequenceGroupInputs] = []
for seq_group in self.running:
group_id = seq_group.group_id
num_steps = self.num_steps[group_id]
# NOTE(woosuk): We assume that the number of steps is 0
# for the prompt sequences.
is_prompt = num_steps == 0
input_tokens: Dict[int, List[int]] = {}
seq_logprobs: Dict[int, float] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
block_tables[seq_id] = self.block_manager.get_block_table(seq)
if is_prompt:
input_tokens[seq_id] = seq.get_token_ids()
else:
input_tokens[seq_id] = [seq.get_last_token_id()]
seq_logprobs[seq_id] = seq.cumulative_logprobs
                # NOTE(woosuk): Sequences in the same group have the same
                # sequence length.
seq_len = seq.get_len()
input_seq_group = SequenceGroupInputs(
group_id=group_id,
is_prompt=is_prompt,
input_tokens=input_tokens,
context_len=seq_len,
seq_logprobs=seq_logprobs,
sampling_params=self.sampling_params[group_id],
block_tables=block_tables,
)
input_seq_groups.append(input_seq_group)
# 5. Execute the first stage of the pipeline.
self.controllers[0].execute_stage(
input_seq_groups,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
)
def post_step(
self,
seq_outputs: Dict[int, SequenceOutputs],
) -> None:
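        """Processes the model outputs of the last step.

        Applies beam-search forks, appends the sampled tokens, frees
        sequences that produced a stop token or reached max_num_steps,
        and returns finished groups to the frontend.
        """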
# Update the running sequences and free blocks.
for seq_group in self.running:
group_id = seq_group.group_id
self.num_steps[group_id] += 1
stop_token_ids = self.sampling_params[group_id].stop_token_ids
# Process beam search results before processing the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
# Free the current sequence.
self.block_manager.free(seq)
# Fork the parent sequence.
parent_seq = seq_group.find(output.parent_seq_id)
parent_seq.fork(seq)
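                    # Let the fork share the parent's blocks by bumping
                    # their reference counts.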
                    self.block_manager.fork(parent_seq, seq)

# Process the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append(output.output_token, output.logprobs)
# Check if the sequence has generated a stop token.
if output.output_token in stop_token_ids:
self._free_seq(seq)
continue
# Check if the sequence has reached the maximum number of steps.
max_num_steps = self.sampling_params[group_id].max_num_steps
if self.num_steps[group_id] == max_num_steps:
self._free_seq(seq)
                    continue

# Update the running sequences.
running: List[SequenceGroup] = []
for seq_group in self.running:
if seq_group.is_finished():
self._return(seq_group)
else:
running.append(seq_group)
self.running = running
def _return(self, seq_group: SequenceGroup) -> None:
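        # Drop the group's bookkeeping and hand the result to the frontend.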
group_id = seq_group.group_id
del self.num_steps[group_id]
del self.sampling_params[group_id]
self.frontend.print_response(seq_group)
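

# A minimal sketch of how a serving loop might drive this scheduler
# (an illustration, with assumptions). `get_step_outputs()` is a
# hypothetical controller method; how the sampled SequenceOutputs come
# back from the last pipeline stage depends on the controller
# implementation.
#
#     scheduler = Scheduler(frontend, controllers, block_size=8,
#                           num_gpu_blocks=1024, num_cpu_blocks=512)
#     while True:
#         scheduler.step()
#         seq_outputs = controllers[-1].get_step_outputs()  # hypothetical
#         scheduler.post_step(seq_outputs)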