vllm/cacheflow/sequence.py

169 lines
5.2 KiB
Python
Raw Normal View History

import copy
2023-02-09 11:26:35 +00:00
import enum
from typing import Dict, List, Optional
2023-02-09 11:26:35 +00:00
from cacheflow.block import LogicalTokenBlock
from cacheflow.sampling_params import SamplingParams
2023-02-09 11:26:35 +00:00
class SequenceStatus(enum.Enum):
    """Status values a ``Sequence`` can hold.

    A freshly constructed ``Sequence`` starts in ``WAITING``.
    """

    # Explicit values, identical to what ``enum.auto()`` assigns (1, 2, 3, 4).
    WAITING = 1
    RUNNING = 2
    SWAPPED = 3
    FINISHED = 4
class Sequence:
    """A single token sequence stored as a list of logical token blocks.

    Token ids are packed into ``LogicalTokenBlock`` objects created with
    ``block_size``; a new block is opened only once the last block reports
    itself full (or when there are no blocks yet).
    """

    def __init__(
        self,
        seq_id: int,
        token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.block_size = block_size
        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the given token ids.
        self.add(token_ids)

        self.status = SequenceStatus.WAITING
        # One entry per token added through `append`: the logprob dict that
        # was passed in with that token. Tokens given to __init__ get none.
        self.output_logprobs: List[Dict[int, float]] = []
        # Running sum of logprobs[token_id] over every `append`ed token.
        self.cumulative_logprobs = 0.0

    def add_block(self) -> None:
        """Append a new logical block numbered after the existing ones."""
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def add(self, token_ids: List[int]) -> None:
        """Append ``token_ids`` to the sequence, filling the last block first.

        Whenever the last block is full (or no block exists), a new one is
        opened via ``add_block``. An empty ``token_ids`` is a no-op.
        """
        # Walk the input with an offset rather than re-slicing the remaining
        # tail each iteration, which copied the whole remainder per block.
        offset = 0
        num_tokens = len(token_ids)
        while offset < num_tokens:
            if (not self.logical_token_blocks
                    or self.logical_token_blocks[-1].is_full()):
                self.add_block()
            last_block = self.logical_token_blocks[-1]
            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append(token_ids[offset:offset + num_empty_slots])
            offset += num_empty_slots

    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
        """Append one token and record the logprobs that came with it.

        ``logprobs`` must contain ``token_id``; its value is added to
        ``cumulative_logprobs`` and the whole dict is kept in
        ``output_logprobs``.
        """
        assert token_id in logprobs
        self.add([token_id])
        self.output_logprobs.append(logprobs)
        self.cumulative_logprobs += logprobs[token_id]

    def get_len(self) -> int:
        """Return the total number of tokens across all logical blocks."""
        return sum(block.num_tokens for block in self.logical_token_blocks)

    def get_token_ids(self) -> List[int]:
        """Return all token ids in order, concatenated block by block."""
        token_ids: List[int] = []
        for block in self.logical_token_blocks:
            token_ids.extend(block.get_token_ids())
        return token_ids

    def get_last_token_id(self) -> int:
        """Return the last token id of the last block.

        Raises ``IndexError`` if the sequence has no blocks (e.g. it was
        constructed with an empty ``token_ids``).
        """
        return self.logical_token_blocks[-1].get_last_token_id()

    def fork(self, child_seq: 'Sequence') -> 'Sequence':
        """Copy this sequence's tokens and logprob history into ``child_seq``.

        Blocks and per-token logprob dicts are deep-copied, so the two
        sequences share no mutable state afterwards. ``child_seq``'s id,
        block size and status are left untouched.

        Returns:
            ``child_seq`` itself, matching the return annotation.
        """
        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
        child_seq.cumulative_logprobs = self.cumulative_logprobs
        return child_seq

    def __repr__(self) -> str:
        return (f'Sequence(seq_id={self.seq_id}, '
                f'status={self.status.name}, '
                f'num_blocks={len(self.logical_token_blocks)})')
class SequenceGroup:
    """A collection of ``Sequence`` objects sharing a group id and arrival time."""

    def __init__(
        self,
        group_id: int,
        seqs: List[Sequence],
        arrival_time: float,
    ) -> None:
        self.group_id = group_id
        self.seqs = seqs
        self.arrival_time = arrival_time

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the group's sequences, optionally filtered by ``status``.

        With ``status=None`` the group's own ``seqs`` list object is returned
        (not a copy); otherwise a new list of matching sequences is built.
        """
        if status is None:
            return self.seqs
        return [member for member in self.seqs if member.status == status]

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count the sequences that ``get_seqs(status)`` would return."""
        selected = self.get_seqs(status)
        return len(selected)

    def find(self, seq_id: int) -> Sequence:
        """Return the first sequence whose ``seq_id`` matches.

        Raises:
            ValueError: if no sequence in the group has that id.
        """
        match = next(
            (member for member in self.seqs if member.seq_id == seq_id),
            None,
        )
        if match is None:
            raise ValueError(f'Sequence {seq_id} not found.')
        return match

    def is_finished(self) -> bool:
        """True when no sequence in the group has a non-FINISHED status."""
        return not any(
            member.status != SequenceStatus.FINISHED for member in self.seqs
        )

    def __repr__(self) -> str:
        return (f'SequenceGroup(group_id={self.group_id}, '
                f'num_seqs={len(self.seqs)})')
class SequenceGroupInputs:
    """Plain container bundling the inputs of one sequence group.

    Args:
        group_id: Id of the sequence group.
        is_prompt: Flag stored as-is.
        input_tokens: Seq id -> token ids.
        context_len: Stored as-is.
        seq_logprobs: Seq id -> cumulative logprobs.
        sampling_params: Sampling parameters for the group.
        block_tables: Seq id -> list of physical block numbers.
    """

    def __init__(
        self,
        group_id: int,
        is_prompt: bool,
        input_tokens: Dict[int, List[int]],
        context_len: int,
        seq_logprobs: Dict[int, float],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
    ) -> None:
        # Attributes mirror the constructor arguments one-to-one.
        self.group_id = group_id
        self.is_prompt = is_prompt
        self.input_tokens = input_tokens
        self.context_len = context_len
        self.seq_logprobs = seq_logprobs
        self.sampling_params = sampling_params
        self.block_tables = block_tables
class SequenceOutputs:
    """The token produced for one sequence, plus its lineage and logprobs.

    Args:
        seq_id: Id of the sequence this output belongs to.
        parent_seq_id: Id of the sequence it was derived from.
        output_token: The produced token id.
        logprobs: Token id -> logP(x_i+1 | x_0, ..., x_i).
    """

    def __init__(
        self,
        seq_id: int,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],  # Token id -> logP(x_i+1 | x_0, ..., x_i).
    ) -> None:
        self.seq_id = seq_id
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        # The closing parenthesis belongs after `logprobs`; it was previously
        # emitted right after `output_token`.
        return (f'SequenceOutputs(seq_id={self.seq_id}, '
                f'parent_seq_id={self.parent_seq_id}, '
                f'output_token={self.output_token}, '
                f'logprobs={self.logprobs})')

    def __eq__(self, other: object) -> bool:
        """Field-wise equality; defers to Python for non-SequenceOutputs."""
        if not isinstance(other, SequenceOutputs):
            # Let Python try the reflected comparison / fall back to identity
            # instead of raising AttributeError on missing fields.
            return NotImplemented
        return (self.seq_id == other.seq_id and
                self.parent_seq_id == other.parent_seq_id and
                self.output_token == other.output_token and
                self.logprobs == other.logprobs)