import time
from typing import List, Optional, Set, Tuple

from transformers import AutoTokenizer

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter


class SimpleFrontend:
    """Minimal synchronous frontend: tokenizes prompts and queues them as
    sequence groups for a scheduler to drain via `get_inputs`."""

    def __init__(
        self,
        model_name: str,
        block_size: int,
    ) -> None:
        # Tokenizer is pulled from the HF hub by model name.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.block_size = block_size
        # Independent counters hand out globally unique ids.
        self.seq_counter = Counter()
        self.seq_group_counter = Counter()
        # Pending (group, params) pairs, drained by get_inputs().
        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []

    def add_eos_token(self, sampling_params: SamplingParams) -> SamplingParams:
        # Register the tokenizer's EOS id so generation halts on it.
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
        return sampling_params

    def query(
        self,
        prompt: str,
        sampling_params: SamplingParams,
    ) -> None:
        """Tokenize `prompt` and enqueue it as a new request."""
        self._add_query(self.tokenizer.encode(prompt), sampling_params)

    def _add_query(
        self,
        token_ids: List[int],
        sampling_params: SamplingParams,
        arrival_time: Optional[float] = None,
    ) -> None:
        """Build `n` sibling sequences from one prompt and queue the group.

        `arrival_time` defaults to now; callers may backdate for testing.
        """
        if arrival_time is None:
            arrival_time = time.time()
        # One Sequence per requested sample, all sharing the prompt tokens.
        seqs: List[Sequence] = [
            Sequence(next(self.seq_counter), token_ids,
                     block_size=self.block_size)
            for _ in range(sampling_params.n)
        ]
        group_id = next(self.seq_group_counter)
        self.inputs.append(
            (SequenceGroup(group_id, seqs, arrival_time), sampling_params))

    def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
        """Return all queued requests and reset the queue."""
        pending, self.inputs = self.inputs, []
        return pending

    def print_response(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Decode and print every sequence in the group."""
        for seq in seq_group.seqs:
            text = self.tokenizer.decode(
                seq.get_token_ids(), skip_special_tokens=True).strip()
            print(f'Seq {seq.seq_id}: {text!r}')