Add Frontend

This commit is contained in:
parent 46ce1356f7
commit 1132fae0ca

cacheflow/master/frontend.py (new file, +56 lines)
cacheflow/master/frontend.py
@@ -0,0 +1,56 @@
+from typing import List, Optional, Tuple
+
+from transformers import AutoTokenizer
+
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import Sequence
+from cacheflow.sequence import SequenceGroup
+from cacheflow.utils import Counter
+
+
+class Frontend:
+
+    def __init__(
+        self,
+        model_name: str,
+        block_size: int,
+    ) -> None:
+        self.block_size = block_size
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.seq_group_counter = Counter()
+        self.seq_counter = Counter()
+        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []
+
+    def query(
+        self,
+        prompt: str,
+        sampling_params: Optional[SamplingParams] = None,
+    ) -> None:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        token_ids: List[int] = self.tokenizer.encode(prompt)
+
+        seqs: List[Sequence] = []
+        for _ in range(sampling_params.n):
+            seq_id = next(self.seq_counter)
+            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
+            seqs.append(seq)
+
+        group_id = next(self.seq_group_counter)
+        seq_group = SequenceGroup(group_id, seqs)
+        self.inputs.append((seq_group, sampling_params))
+
+    def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
+        inputs = self.inputs
+        self.inputs = []
+        return inputs
+
+    def print_response(
+        self,
+        seq_group: SequenceGroup,
+    ) -> None:
+        for seq in seq_group.seqs:
+            token_ids = seq.get_token_ids()
+            output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
+            print(f'Seq {seq.seq_id}: {output}')
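For orientation, here is a minimal usage sketch of the Frontend class added above. It is not part of the commit: the model name reuses 'facebook/opt-125m' from the scheduler hunk further down, block_size=8 is an arbitrary illustrative value, and SamplingParams (imported from cacheflow.sampling_params) is not shown in this diff.

    # Minimal usage sketch of the new Frontend (illustrative, not in this commit).
    # block_size=8 is an arbitrary value; query() with no sampling params falls
    # back to SamplingParams() defaults, as in the code above.
    from cacheflow.master.frontend import Frontend

    frontend = Frontend(model_name='facebook/opt-125m', block_size=8)
    frontend.query('Hello, my name is')

    # The scheduler drains the buffered (SequenceGroup, SamplingParams) pairs.
    for seq_group, sampling_params in frontend.get_inputs():
        print(seq_group.group_id, sampling_params)

    # Once a group finishes decoding, the scheduler hands it back:
    # frontend.print_response(seq_group)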
cacheflow/master/scheduler.py
@@ -1,6 +1,8 @@
 from typing import Dict, List, Tuple
 
 from cacheflow.master.block_manager import BlockSpaceManager
+from cacheflow.master.frontend import Frontend
+from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sequence import SequenceStatus
@@ -12,11 +14,13 @@ class Scheduler:
 
     def __init__(
         self,
+        frontend: Frontend,
         controllers: List,
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
     ) -> None:
+        self.frontend = frontend
         self.controllers = controllers
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
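With the new frontend parameter, wiring the pieces together presumably looks like the sketch below. The controllers list, the block counts, and the scheduler module path are placeholders and assumptions, not values taken from this commit.

    # Hypothetical wiring of Frontend into Scheduler; values are placeholders.
    from cacheflow.master.frontend import Frontend
    from cacheflow.master.scheduler import Scheduler  # module path assumed

    block_size = 8
    frontend = Frontend(model_name='facebook/opt-125m', block_size=block_size)
    scheduler = Scheduler(
        frontend=frontend,
        controllers=[],        # worker controllers, elided here
        block_size=block_size,
        num_gpu_blocks=1024,   # placeholder capacity
        num_cpu_blocks=256,    # placeholder capacity
    )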
@@ -33,16 +37,20 @@
         self.running: List[SequenceGroup] = []
         # Mapping: group_id -> num_steps.
         self.num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> max_num_steps.
-        self.max_num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> stop_token_ids.
-        self.stop_token_ids: Dict[int, List[int]] = {}
+        # Mapping: group_id -> sampling params.
+        self.sampling_params: Dict[int, SamplingParams] = {}
 
         # Swapped sequence groups (LIFO).
         self.swapped: List[SequenceGroup] = []
         # Pending sequence groups (FIFO).
         self.pending: List[SequenceGroup] = []
 
+    def _fetch_inputs(self) -> None:
+        inputs = self.frontend.get_inputs()
+        for seq_group, sampling_params in inputs:
+            self.pending.append(seq_group)
+            self.sampling_params[seq_group.group_id] = sampling_params
+
     def _free_seq(self, seq: Sequence) -> None:
         seq.status = SequenceStatus.FINISHED
         self.block_manager.free(seq)
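This hunk folds the per-group max_num_steps and stop_token_ids dicts into a single sampling_params mapping and adds _fetch_inputs() to pull buffered requests from the frontend. SamplingParams itself is not part of the diff; judging only by the attributes read here and in frontend.py (n, max_num_steps, stop_token_ids), a rough stand-in could look like:

    # Rough stand-in for cacheflow.sampling_params.SamplingParams. Field names
    # come from the attributes used in this diff; the defaults are assumptions.
    from typing import List, Optional

    class SamplingParamsSketch:

        def __init__(
            self,
            n: int = 1,                                  # sequences per prompt
            max_num_steps: int = 16,                     # assumed generation cap
            stop_token_ids: Optional[List[int]] = None,  # tokens that end a sequence
        ) -> None:
            self.n = n
            self.max_num_steps = max_num_steps
            self.stop_token_ids = stop_token_ids if stop_token_ids is not None else []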
@@ -145,6 +153,7 @@
         # TODO(woosuk): Add a batching policy to control the batch size.
         if not self.swapped:
             # FIXME(woosuk): Acquire a lock to protect pending.
+            self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
                 if self.block_manager.can_allocate(seq_group):
@@ -205,7 +214,7 @@
         for seq_group in self.running:
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
-            stop_token_ids = self.stop_token_ids[group_id]
+            stop_token_ids = self.sampling_params[group_id].stop_token_ids
 
             for seq in seq_group.seqs:
                 if seq.status == SequenceStatus.FINISHED:
@@ -230,24 +239,22 @@
                 continue
 
                 # Check if the sequence has reached the maximum number of steps.
-                if self.num_steps[group_id] == self.max_num_steps[group_id]:
+                max_num_steps = self.sampling_params[group_id].max_num_steps
+                if self.num_steps[group_id] == max_num_steps:
                     self._free_seq(seq)
                     continue
 
         # Update the running sequences.
         running: List[SequenceGroup] = []
         for seq_group in self.running:
-            if all(seq.status == SequenceStatus.FINISHED for seq in seq_group.seqs):
-                del self.num_steps[seq_group.group_id]
-                del self.max_num_steps[seq_group.group_id]
-                del self.stop_token_ids[seq_group.group_id]
-                # TODO: Return the seq_group to the client.
-                from transformers import AutoTokenizer
-                tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
-                for seq in seq_group.seqs:
-                    token_ids = seq.get_token_ids()
-                    output = tokenizer.decode(token_ids, skip_special_tokens=True)
-                    print(f'Seq {seq.seq_id}: {output}')
+            if seq_group.is_finished():
+                self._return(seq_group)
             else:
                 running.append(seq_group)
         self.running = running
+
+    def _return(self, seq_group: SequenceGroup) -> None:
+        group_id = seq_group.group_id
+        del self.num_steps[group_id]
+        del self.sampling_params[group_id]
+        self.frontend.print_response(seq_group)
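The removed decode-and-print block (with its hard-coded 'facebook/opt-125m' tokenizer) is replaced by SequenceGroup.is_finished() plus a _return() helper that delegates printing to the frontend's own tokenizer. is_finished() is defined on SequenceGroup in cacheflow.sequence and is not shown here; judging from the inline check it replaces, it is presumably equivalent to:

    # Presumed behavior of SequenceGroup.is_finished(), reconstructed from the
    # all(...) check this hunk removes; the real method is not part of this diff.
    from cacheflow.sequence import SequenceStatus

    def is_finished(seq_group) -> bool:
        return all(seq.status == SequenceStatus.FINISHED for seq in seq_group.seqs)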