Add Frontend

parent 46ce1356f7
commit 1132fae0ca
cacheflow/master/frontend.py (new file, +56 lines)
@@ -0,0 +1,56 @@
from typing import List, Optional, Tuple

from transformers import AutoTokenizer

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.utils import Counter


class Frontend:

    def __init__(
        self,
        model_name: str,
        block_size: int,
    ) -> None:
        self.block_size = block_size

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []

    def query(
        self,
        prompt: str,
        sampling_params: Optional[SamplingParams] = None,
    ) -> None:
        if sampling_params is None:
            sampling_params = SamplingParams()
        token_ids: List[int] = self.tokenizer.encode(prompt)

        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
            seqs.append(seq)

        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs)
        self.inputs.append((seq_group, sampling_params))

    def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
        inputs = self.inputs
        self.inputs = []
        return inputs

    def print_response(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        for seq in seq_group.seqs:
            token_ids = seq.get_token_ids()
            output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
            print(f'Seq {seq.seq_id}: {output}')
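For orientation, here is a minimal usage sketch of the new Frontend (not part of the commit). It assumes SamplingParams can be constructed with n as a keyword argument and that a block_size of 8 is valid; the 'facebook/opt-125m' model name mirrors the one used elsewhere in this diff.

# Hypothetical driver code, not part of this commit.
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams

frontend = Frontend(model_name='facebook/opt-125m', block_size=8)

# Tokenize a prompt and queue a SequenceGroup with two parallel sequences
# (assumes SamplingParams accepts `n` as a keyword argument).
frontend.query('Hello, my name is', SamplingParams(n=2))

# A consumer (the Scheduler in this commit) drains the buffered
# (SequenceGroup, SamplingParams) pairs; get_inputs() empties the buffer.
for seq_group, sampling_params in frontend.get_inputs():
    pass  # hand each group off to the scheduler here

# Once a group finishes, the scheduler calls back so the frontend can
# detokenize and print each sequence:
# frontend.print_response(seq_group)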
cacheflow/master/scheduler.py

@@ -1,6 +1,8 @@
 from typing import Dict, List, Tuple

 from cacheflow.master.block_manager import BlockSpaceManager
+from cacheflow.master.frontend import Frontend
+from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sequence import SequenceStatus
@@ -12,11 +14,13 @@ class Scheduler:

     def __init__(
         self,
+        frontend: Frontend,
         controllers: List,
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
     ) -> None:
+        self.frontend = frontend
         self.controllers = controllers
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
@@ -33,16 +37,20 @@ class Scheduler:
         self.running: List[SequenceGroup] = []
         # Mapping: group_id -> num_steps.
         self.num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> max_num_steps.
-        self.max_num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> stop_token_ids.
-        self.stop_token_ids: Dict[int, List[int]] = {}
+        # Mapping: group_id -> sampling params.
+        self.sampling_params: Dict[int, SamplingParams] = {}

         # Swapped sequence groups (LIFO).
         self.swapped: List[SequenceGroup] = []
         # Pending sequence groups (FIFO).
         self.pending: List[SequenceGroup] = []

+    def _fetch_inputs(self) -> None:
+        inputs = self.frontend.get_inputs()
+        for seq_group, sampling_params in inputs:
+            self.pending.append(seq_group)
+            self.sampling_params[seq_group.group_id] = sampling_params
+
     def _free_seq(self, seq: Sequence) -> None:
         seq.status = SequenceStatus.FINISHED
         self.block_manager.free(seq)
@@ -145,6 +153,7 @@ class Scheduler:
         # TODO(woosuk): Add a batching policy to control the batch size.
         if not self.swapped:
             # FIXME(woosuk): Acquire a lock to protect pending.
+            self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
                 if self.block_manager.can_allocate(seq_group):
@@ -205,7 +214,7 @@ class Scheduler:
         for seq_group in self.running:
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
-            stop_token_ids = self.stop_token_ids[group_id]
+            stop_token_ids = self.sampling_params[group_id].stop_token_ids

             for seq in seq_group.seqs:
                 if seq.status == SequenceStatus.FINISHED:
@@ -230,24 +239,22 @@ class Scheduler:
                     continue

                 # Check if the sequence has reached the maximum number of steps.
-                if self.num_steps[group_id] == self.max_num_steps[group_id]:
+                max_num_steps = self.sampling_params[group_id].max_num_steps
+                if self.num_steps[group_id] == max_num_steps:
                     self._free_seq(seq)
                     continue

         # Update the running sequences.
         running: List[SequenceGroup] = []
         for seq_group in self.running:
-            if all(seq.status == SequenceStatus.FINISHED for seq in seq_group.seqs):
-                del self.num_steps[seq_group.group_id]
-                del self.max_num_steps[seq_group.group_id]
-                del self.stop_token_ids[seq_group.group_id]
-                # TODO: Return the seq_group to the client.
-                from transformers import AutoTokenizer
-                tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
-                for seq in seq_group.seqs:
-                    token_ids = seq.get_token_ids()
-                    output = tokenizer.decode(token_ids, skip_special_tokens=True)
-                    print(f'Seq {seq.seq_id}: {output}')
+            if seq_group.is_finished():
+                self._return(seq_group)
             else:
                 running.append(seq_group)
         self.running = running

+    def _return(self, seq_group: SequenceGroup) -> None:
+        group_id = seq_group.group_id
+        del self.num_steps[group_id]
+        del self.sampling_params[group_id]
+        self.frontend.print_response(seq_group)
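With this change the scheduler reads per-group sampling settings (stop_token_ids, max_num_steps, and, via the frontend, n) from the stored SamplingParams instead of keeping three separate dicts. The real class lives in cacheflow/sampling_params.py and is not touched by this commit; the sketch below only approximates the fields this diff depends on, and the default values are made up for illustration.

# Illustrative approximation of the SamplingParams fields used by this diff;
# the actual definition in cacheflow/sampling_params.py is not part of this commit.
from typing import List, Optional


class SamplingParamsSketch:

    def __init__(
        self,
        n: int = 1,                                  # sequences per prompt (looped over in Frontend.query)
        max_num_steps: int = 16,                     # scheduler frees a sequence once num_steps reaches this
        stop_token_ids: Optional[List[int]] = None,  # generated tokens checked against these to stop early
    ) -> None:
        self.n = n
        self.max_num_steps = max_num_steps
        self.stop_token_ids = stop_token_ids if stop_token_ids is not None else []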