Add Frontend

This commit is contained in:
Woosuk Kwon 2023-02-24 11:46:43 +00:00
parent 46ce1356f7
commit 1132fae0ca
2 changed files with 80 additions and 17 deletions

View File

@@ -0,0 +1,56 @@
from typing import List, Optional, Tuple
from transformers import AutoTokenizer
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.utils import Counter
class Frontend:
    """Entry point between callers and the scheduler.

    `query` tokenizes a prompt into a SequenceGroup and queues it,
    `get_inputs` drains the queue for the scheduler, and
    `print_response` decodes finished sequences back into text.
    """

    def __init__(
        self,
        model_name: str,
        block_size: int,
    ) -> None:
        self.block_size = block_size
        # Tokenizer matching the target model's vocabulary.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Monotonic id generators for sequence groups and sequences.
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        # Queued (group, sampling params) pairs awaiting pickup.
        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []

    def query(
        self,
        prompt: str,
        sampling_params: Optional[SamplingParams] = None,
    ) -> None:
        """Tokenize `prompt` and enqueue `n` parallel sequences for it."""
        params = SamplingParams() if sampling_params is None else sampling_params
        token_ids: List[int] = self.tokenizer.encode(prompt)
        # One sequence per requested sample; all start from the prompt tokens.
        seqs: List[Sequence] = [
            Sequence(next(self.seq_counter), token_ids,
                     block_size=self.block_size)
            for _ in range(params.n)
        ]
        group = SequenceGroup(next(self.seq_group_counter), seqs)
        self.inputs.append((group, params))

    def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
        """Return all queued inputs and clear the queue."""
        queued, self.inputs = self.inputs, []
        return queued

    def print_response(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Decode each sequence in `seq_group` and print it."""
        for seq in seq_group.seqs:
            text = self.tokenizer.decode(seq.get_token_ids(),
                                         skip_special_tokens=True)
            print(f'Seq {seq.seq_id}: {text}')

View File

@@ -1,6 +1,8 @@
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from cacheflow.master.block_manager import BlockSpaceManager from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceStatus from cacheflow.sequence import SequenceStatus
@@ -12,11 +14,13 @@ class Scheduler:
def __init__( def __init__(
self, self,
frontend: Frontend,
controllers: List, controllers: List,
block_size: int, block_size: int,
num_gpu_blocks: int, num_gpu_blocks: int,
num_cpu_blocks: int, num_cpu_blocks: int,
) -> None: ) -> None:
self.frontend = frontend
self.controllers = controllers self.controllers = controllers
self.block_size = block_size self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks self.num_gpu_blocks = num_gpu_blocks
@@ -33,16 +37,20 @@
self.running: List[SequenceGroup] = [] self.running: List[SequenceGroup] = []
# Mapping: group_id -> num_steps. # Mapping: group_id -> num_steps.
self.num_steps: Dict[int, int] = {} self.num_steps: Dict[int, int] = {}
# Mapping: group_id -> max_num_steps. # Mapping: group_id -> sampling params.
self.max_num_steps: Dict[int, int] = {} self.sampling_params: Dict[int, SamplingParams] = {}
# Mapping: group_id -> stop_token_ids.
self.stop_token_ids: Dict[int, List[int]] = {}
# Swapped sequence groups (LIFO). # Swapped sequence groups (LIFO).
self.swapped: List[SequenceGroup] = [] self.swapped: List[SequenceGroup] = []
# Pending sequence groups (FIFO). # Pending sequence groups (FIFO).
self.pending: List[SequenceGroup] = [] self.pending: List[SequenceGroup] = []
def _fetch_inputs(self) -> None:
    """Drain the frontend queue into the pending list.

    Registers each new sequence group as pending and records its
    sampling parameters keyed by group id for later lookup.
    """
    for seq_group, sampling_params in self.frontend.get_inputs():
        self.sampling_params[seq_group.group_id] = sampling_params
        self.pending.append(seq_group)
def _free_seq(self, seq: Sequence) -> None: def _free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq) self.block_manager.free(seq)
@@ -145,6 +153,7 @@ class Scheduler:
# TODO(woosuk): Add a batching policy to control the batch size. # TODO(woosuk): Add a batching policy to control the batch size.
if not self.swapped: if not self.swapped:
# FIXME(woosuk): Acquire a lock to protect pending. # FIXME(woosuk): Acquire a lock to protect pending.
self._fetch_inputs()
for i, seq_group in enumerate(self.pending): for i, seq_group in enumerate(self.pending):
num_prompt_tokens = seq_group.seqs[0].get_len() num_prompt_tokens = seq_group.seqs[0].get_len()
if self.block_manager.can_allocate(seq_group): if self.block_manager.can_allocate(seq_group):
@@ -205,7 +214,7 @@
for seq_group in self.running: for seq_group in self.running:
group_id = seq_group.group_id group_id = seq_group.group_id
self.num_steps[group_id] += 1 self.num_steps[group_id] += 1
stop_token_ids = self.stop_token_ids[group_id] stop_token_ids = self.sampling_params[group_id].stop_token_ids
for seq in seq_group.seqs: for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED: if seq.status == SequenceStatus.FINISHED:
@@ -230,24 +239,22 @@
continue continue
# Check if the sequence has reached the maximum number of steps. # Check if the sequence has reached the maximum number of steps.
if self.num_steps[group_id] == self.max_num_steps[group_id]: max_num_steps = self.sampling_params[group_id].max_num_steps
if self.num_steps[group_id] == max_num_steps:
self._free_seq(seq) self._free_seq(seq)
continue continue
# Update the running sequences. # Update the running sequences.
running: List[SequenceGroup] = [] running: List[SequenceGroup] = []
for seq_group in self.running: for seq_group in self.running:
if all(seq.status == SequenceStatus.FINISHED for seq in seq_group.seqs): if seq_group.is_finished():
del self.num_steps[seq_group.group_id] self._return(seq_group)
del self.max_num_steps[seq_group.group_id]
del self.stop_token_ids[seq_group.group_id]
# TODO: Return the seq_group to the client.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
for seq in seq_group.seqs:
token_ids = seq.get_token_ids()
output = tokenizer.decode(token_ids, skip_special_tokens=True)
print(f'Seq {seq.seq_id}: {output}')
else: else:
running.append(seq_group) running.append(seq_group)
self.running = running self.running = running
def _return(self, seq_group: SequenceGroup) -> None:
    """Retire a finished sequence group.

    Drops the group's per-step and sampling-parameter bookkeeping,
    then hands the group back to the frontend for decoding/printing.
    """
    gid = seq_group.group_id
    self.num_steps.pop(gid)
    self.sampling_params.pop(gid)
    self.frontend.print_response(seq_group)