900 lines
33 KiB
Python
900 lines
33 KiB
Python
"""Sequence and its related classes."""
|
|
import copy
|
|
import enum
|
|
import math
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
|
|
from vllm.lora.request import LoRARequest
|
|
from vllm.pooling_params import PoolingParams
|
|
from vllm.sampling_params import SamplingParams
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.inputs import LLMInputs
|
|
from vllm.multimodal import MultiModalData
|
|
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
|
|
|
|
|
@dataclass
class Logprob:
    """Log-probability record for one token.

    Used to support OpenAI-compatible logprobs and token ranks.

    Attributes:
        logprob: The log probability of the chosen token.
        rank: The vocabulary rank of the chosen token (>= 1). Defaults to
            None.
        decoded_token: The decoded form of the chosen token; typed as a
            string here. Defaults to None.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None
|
|
|
|
|
|
# {token_id -> logprob} for each sequence group. An entry is None if the
# corresponding sequence group doesn't require prompt logprobs.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]]
|
|
|
|
|
|
class SequenceStatus(enum.Enum):
    """Lifecycle states of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when ``status`` is one of the FINISHED_* states."""
        terminal_states = (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        )
        return status in terminal_states

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a status to its finish-reason string.

        Returns "stop", "length" or "abort" for the FINISHED_* states and
        None for every other status.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        if status == SequenceStatus.FINISHED_IGNORED:
            # An ignored sequence is one whose prompt is longer than the
            # model's length cap, so it is reported as "length" as well,
            # matching the OpenAI API.
            return "length"
        return None
|
|
|
|
|
|
class SequenceStage(enum.Enum):
    """Processing stage of a sequence (prefill or decode)."""
    # Stage a SequenceData starts in, and returns to after a reset for
    # recomputation.
    PREFILL = enum.auto()
    # Stage entered once no uncomputed tokens remain.
    DECODE = enum.auto()
|
|
|
|
|
|
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: Reference timestamp for per-token latency.
            SequenceGroup initialises it to the arrival time and advances
            it on every `get_last_latency` call.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
|
|
|
|
|
|
class SequenceData:
    """Token-level data of a single sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. A fresh empty list is
            used when None is given.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The running sum of the logprobs passed to
            `append_token_id`.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        # Tuple snapshot of the prompt, reused by get_prefix_token_ids() so
        # that its return value is hashable.
        self._prompt_token_ids_tuple = tuple(prompt_token_ids)
        self.output_token_ids = ([] if output_token_ids is None else
                                 output_token_ids)
        self.cumulative_logprob = 0.0
        # How many tokens have already been run against the model.
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Record one generated token and accumulate its logprob."""
        self.cumulative_logprob += logprob
        self.output_token_ids.append(token_id)

    def get_len(self) -> int:
        """Total number of tokens (prompt plus output)."""
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def get_prompt_len(self) -> int:
        """Number of prompt tokens."""
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Number of output tokens."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list holding the prompt followed by the output."""
        return self.prompt_token_ids + self.output_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Return the first ``num_tokens`` tokens in a hashable form.

        The result is a ``(prompt_part, output_part)`` pair; ``output_part``
        is None when the prefix lies entirely inside the prompt.
        """
        prompt_len = len(self.prompt_token_ids)
        if num_tokens <= prompt_len:
            return (self._prompt_token_ids_tuple[:num_tokens], None)
        num_output_tokens = num_tokens - prompt_len
        return (self._prompt_token_ids_tuple,
                tuple(self.output_token_ids[:num_output_tokens]))

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Add ``num_new_computed_tokens`` to the computed-token counter."""
        self._num_computed_tokens += num_new_computed_tokens
        total_len = self.get_len()
        assert self._num_computed_tokens <= total_len, (
            self._num_computed_tokens, total_len)
        # Nothing left to compute means the sequence is now decoding.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Forget all computed tokens and go back to the prefill stage.

        Meant for a sequence that has to be started from the beginning
        again (e.g., the sequence is preempted).
        """
        self._stage = SequenceStage.PREFILL
        self._num_computed_tokens = 0

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of tokens that are not computed yet."""
        # The total length (prompt + output) is used rather than the prompt
        # length alone: during recompute both prompt and output tokens have
        # to be prefilled.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if there
        is no output yet."""
        if self.output_token_ids:
            return self.output_token_ids[-1]
        return self.prompt_token_ids[-1]

    def get_prompt_token_ids(self) -> List[int]:
        """Return the prompt token list itself (not a copy)."""
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
        """Return the output token list itself (not a copy)."""
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        """Current processing stage."""
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
|
|
|
|
|
|
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence token ID, stored on the instance.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: "LLMInputs",
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        # `prompt_token_ids` is a property that reads from `self.inputs`.
        self.data = SequenceData(self.prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of blocks needed to hold the sequence (rounded up)."""
        return math.ceil(self.get_len() / self.block_size)

    @property
    def prompt(self) -> Optional[str]:
        """The "prompt" entry of the inputs, or None when absent."""
        return self.inputs.get("prompt")

    @property
    def prompt_token_ids(self) -> List[int]:
        """The "prompt_token_ids" entry of the inputs."""
        return self.inputs["prompt_token_ids"]

    @property
    def multi_modal_data(self) -> Optional["MultiModalData"]:
        """The "multi_modal_data" entry of the inputs, or None when absent."""
        return self.inputs.get("multi_modal_data")

    @property
    def lora_int_id(self) -> int:
        """The LoRA integer ID, or 0 when there is no LoRA request."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_output_text_to_return(self, buffer_length: int) -> str:
        """Return the output text, holding back the last ``buffer_length``
        characters while the sequence is still unfinished."""
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash all tokens up to the end of block ``logical_idx`` together
        with the LoRA ID."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int) -> int:
        """Number of tokens from the sequence start to the end of block
        ``logical_idx``, i.e. ``(logical_idx + 1) * block_size``."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self) -> None:
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append a generated token together with its logprob dictionary.

        ``logprobs`` must contain an entry for ``token_id``.
        """
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        """Total number of tokens (prompt plus output)."""
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        """Number of prompt tokens."""
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        """Number of output tokens."""
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        """Prompt token IDs followed by output token IDs."""
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> List[int]:
        """The prompt token IDs held by the sequence data."""
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        """The last token ID of the sequence."""
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        """The output token IDs held by the sequence data."""
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        """The cumulative logprob held by the sequence data."""
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        The score is ``cumulative_logprob / seq_len ** length_penalty``.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

        Args:
            length_penalty: Exponent applied to the sequence length.
            seq_len: Length to use. When None, the current sequence length is
                used, minus one if the last token equals ``eos_token_id``.
            eos_token_id: EOS token ID used for the adjustment above.
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        """Whether the sequence is in one of the finished states."""
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence carrying ``new_seq_id``."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        """Whether the sequence is still in the prefill stage."""
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        # The last field closes the parenthesis opened by "Sequence(" so
        # that the representation is balanced.
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
|
|
|
|
|
|
@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group."""

    # torch.Generator used in seeded sampling. Defaults to None.
    generator: Optional = None  # type: ignore
|
|
|
|
|
|
class SequenceGroup:
    """A set of sequences that all originate from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
            unless you are working with an encoder/decoder model.
        trace_headers: OpenTelemetry trace headers.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        self.request_id = request_id
        # Sequences are indexed by their ID.
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        # The last-token timestamp starts out at the arrival time.
        self.metrics = RequestMetrics(
            arrival_time=arrival_time,
            last_token_time=arrival_time,
            first_scheduled_time=None,
            first_token_time=None,
            time_in_queue=None,
        )
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.encoder_seq = encoder_seq
        self.trace_headers = trace_headers

    @property
    def prompt(self) -> Optional[str]:
        # Every sequence of the group is expected to share one prompt, so
        # any member can be asked for it.
        any_seq = next(iter(self.seqs_dict.values()))
        return any_seq.prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # Every sequence of the group is expected to share one prompt, so
        # any member can be asked for it.
        any_seq = next(iter(self.seqs_dict.values()))
        return any_seq.prompt_token_ids

    @property
    def multi_modal_data(self) -> Optional["MultiModalData"]:
        # Every sequence of the group is expected to share the same
        # multi-modal data, so any member can be asked for it.
        any_seq = next(iter(self.seqs_dict.values()))
        return any_seq.multi_modal_data

    @property
    def lora_int_id(self) -> int:
        if self.lora_request:
            return self.lora_request.lora_int_id
        return 0

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the time elapsed since the previous token timestamp and
        move that timestamp to ``now``.

        Raises:
            ValueError: If the group is still in the prefill phase.
        """
        if self.is_prefill():
            raise ValueError(
                "seq_group.get_last_latency() should not be called "
                "if the seq_group is in prefill phase.")

        previous_token_time = self.metrics.last_token_time
        self.metrics.last_token_time = now
        return now - previous_token_time

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # Note: when a sequence group is swapped and recomputed, the time
        # between iterations is counted in TPOT rather than recalculating
        # TTFT, since from the user's point of view there is simply a long
        # generation delay.
        if self.metrics.first_token_time is not None:
            return
        if self.get_seqs()[0].get_output_len() == 1:
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is not None:
            return
        self.metrics.first_scheduled_time = time
        self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        params = self.sampling_params
        # Beam search: there will always be `best_of` beam candidates
        # running in the future.
        # Otherwise, at prompt stage the group is not filled up yet and only
        # one sequence runs, but `best_of` sequences will run once
        # generation starts.
        if params and (params.use_beam_search
                       or params.best_of > self.num_seqs()):
            return params.best_of
        # At sampling stages, count the sequences that are not finished yet.
        return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those whose status is ``status``."""
        members = self.seqs_dict.values()
        if status is None:
            return list(members)
        return [seq for seq in members if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        """Whether an encoder sequence is attached to the group."""
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        """Return the encoder sequence, if any."""
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        """Return the sequences that are not finished."""
        unfinished: List[Sequence] = []
        for seq in self.seqs_dict.values():
            if not seq.is_finished():
                unfinished.append(seq)
        return unfinished

    def get_finished_seqs(self) -> List[Sequence]:
        """Return the sequences that are finished."""
        finished: List[Sequence] = []
        for seq in self.seqs_dict.values():
            if seq.is_finished():
                finished.append(seq)
        return finished

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs_dict.values():
            if seq.is_finished():
                continue
            seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Sum of uncomputed tokens over all unfinished sequences."""
        return sum(seq.data.get_num_uncomputed_tokens()
                   for seq in self.get_seqs() if not seq.is_finished())

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        """Count all sequences, or only those whose status is ``status``."""
        # Without a status filter the dictionary size is enough; there is no
        # need to build the list through get_seqs().
        if status is not None:
            return len(self.get_seqs(status))
        return len(self.seqs_dict)

    def num_unfinished_seqs(self) -> int:
        """Number of sequences that are not finished."""
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        """Number of sequences that are finished."""
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Look up a sequence by ID; raise ValueError if it is unknown."""
        if seq_id in self.seqs_dict:
            return self.seqs_dict[seq_id]
        raise ValueError(f"Sequence {seq_id} not found.")

    def add(self, seq: Sequence) -> None:
        """Register a sequence; raise ValueError if its ID is taken."""
        new_id = seq.seq_id
        if new_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[new_id] = seq

    def remove(self, seq_id: int) -> None:
        """Drop a sequence by ID; raise ValueError if it is unknown."""
        if seq_id in self.seqs_dict:
            del self.seqs_dict[seq_id]
            return
        raise ValueError(f"Sequence {seq_id} not found.")

    def is_finished(self) -> bool:
        """Whether every sequence of the group is finished."""
        for seq in self.get_seqs():
            if not seq.is_finished():
                return False
        return True

    def is_prefill(self) -> bool:
        # All sequences are expected to be in the same stage, so the first
        # one is representative.
        first_seq = self.get_seqs()[0]
        return first_seq.is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
|
|
|
|
|
|
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: The pooling parameters, stored as given.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
            (SequenceGroup.encoder_seq). Should be None unless you are
            working with an encoder/decoder model.
        cross_block_table: Optional cross-attention block table associated
            with the encoder prompt (SequenceGroup.encoder_seq). Should be
            None unless you are working with an encoder/decoder model.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional["MultiModalData"] = None,
        encoder_seq_data: Optional[SequenceData] = None,
        cross_block_table: Optional[List[int]] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        # A fresh state object is created when the caller supplies none.
        if state is None:
            self.state = SequenceGroupState()
        else:
            self.state = state
        self.encoder_seq_data = encoder_seq_data
        self.cross_block_table = cross_block_table
        self.do_sample = do_sample

        # Number of speculative tokens adopted in this request:
        #   None -> speculative decoding is not used.
        #   0    -> speculative decoding is disabled for some reason.
        # TODO: We should maintain this state outside of the sequence group.
        self.num_speculative_tokens = None

        # Without an explicit chunk size, a prompt is processed in full
        # (length of the first sequence) and a decode step handles one token.
        if token_chunk_size is None:
            if is_prompt:
                token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                token_chunk_size = 1
        self._token_chunk_size = token_chunk_size

    @property
    def lora_int_id(self) -> int:
        if self.lora_request:
            return self.lora_request.lora_int_id
        return 0

    @property
    def token_chunk_size(self) -> int:
        """Return the number of tokens to be processed (chunk size)."""
        chunk_size = self._token_chunk_size
        assert chunk_size is not None
        return chunk_size
|
|
|
|
|
|
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: "Dict[int, Logprob]",
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare parent sequence ID, output token and logprobs.

        For operands that are not a SequenceOutput, NotImplemented is
        returned so that Python falls back to its default handling
        (``==`` evaluates to False) instead of raising.
        """
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        equal = (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token)
        log_probs_equal = other.logprobs == self.logprobs
        return equal and log_probs_equal
|
|
|
|
|
|
class SequenceGroupOutput(ABC):
    """The base class for model outputs associated with a sequence group.

    Both `__repr__` and `__eq__` are abstract, so concrete subclasses have to
    implement them.
    """

    @abstractmethod
    def __repr__(self) -> str:
        """Return a string representation of the output."""
        pass

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Compare this output with ``other``."""
        pass
|
|
|
|
|
|
class CompletionSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with a completion sequence group.

    Args:
        samples: The sampled outputs of the group.
        prompt_logprobs: Prompt logprob for each prompt query token, or None.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        # Prompt logprob for each prompt query token.
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs.

        For operands that are not a CompletionSequenceGroupOutput,
        NotImplemented is returned so that Python falls back to its default
        handling (``==`` evaluates to False) instead of raising.
        """
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)
|
|
|
|
|
|
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
    """The model output associated with an embedding sequence group.

    Args:
        embeddings: The embedding values of the group.
    """

    def __init__(
        self,
        embeddings: List[float],
    ) -> None:
        self.embeddings = embeddings

    def __repr__(self) -> str:
        # Only the length is shown, not the values.
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Compare the embeddings.

        For operands that are not an EmbeddingSequenceGroupOutput,
        NotImplemented is returned so that Python falls back to its default
        handling (``==`` evaluates to False) instead of raising.
        """
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            return NotImplemented
        return self.embeddings == other.embeddings
|
|
|
|
|
|
@dataclass
class SamplerOutput:
    """Sampler result for a batch of sequence groups.

    For each sequence group there is a list of SequenceOutput objects, each
    holding one possible candidate for the next token.

    The class implements the item/length protocol over `outputs`, so it can
    be used like a list, and additionally carries optional device tensors.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor with the probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor with the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor with the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # Only `outputs` takes part in the comparison.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Render tensors by their shape rather than their values, which
        keeps the output short.
        """
        token_probs = self.sampled_token_probs
        token_ids = self.sampled_token_ids
        sampled_token_probs_repr = (token_probs.shape
                                    if token_probs is not None else "None")
        sampled_token_ids_repr = (token_ids.shape
                                  if token_ids is not None else "None")
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={sampled_token_probs_repr}, "
            f"sampled_token_ids={sampled_token_ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
|
|
|
|
|
|
@dataclass
class PoolerOutput:
    """Result of a pooling operation in the embedding model.

    Supports the item/length protocol over `outputs`.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # Only `outputs` takes part in the comparison.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs
|
|
|
|
|
|
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Collect the sequence ids of every SequenceGroupMetadata in the list.

    The ids are returned in one flat list, group by group, in the iteration
    order of each group's `seq_data`.
    """
    all_seq_ids: List[int] = []
    for seq_group_metadata in seq_group_metadata_list:
        # Iterating `seq_data` yields its keys, i.e. the sequence ids.
        all_seq_ids.extend(seq_group_metadata.seq_data)
    return all_seq_ids
|
|
|
|
|
|
class HiddenStates:
    """Hidden states belonging to in-progress sequences.

    Used in speculative decoding to hand hidden states from the target model
    over to the proposer model in the subsequent step.

    `seq_ids` holds the sequence id of each entry along the batch dimension
    of the `hidden_states` tensor."""

    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
                 hidden_states: torch.Tensor):
        assert len(seq_group_metadata_list) == len(hidden_states)
        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
        self.hidden_states: torch.Tensor = hidden_states

    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
               hidden_states: torch.Tensor) -> None:
        """Append the ids and hidden states of a target model invocation."""
        assert len(seq_group_metadata_list) == len(hidden_states)
        new_seq_ids = get_all_seq_ids(seq_group_metadata_list)
        self.seq_ids.extend(new_seq_ids)
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Keep only the entries of the given sequence groups."""
        kept_seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if kept_seq_ids == self.seq_ids:
            # Batch contents unchanged - nothing to prune.
            return
        # Row of each surviving sequence in the current tensor.
        rows = [self.seq_ids.index(seq_id) for seq_id in kept_seq_ids]
        self.hidden_states = self.hidden_states[rows]
        self.seq_ids = kept_seq_ids
|
|
|
|
|
|
@dataclass
class ExecuteModelRequest:
    """A model execution request, holding CPU metadata only. The LLM engine
    should create one instance of this class per request batch."""
    # Metadata of the sequence groups in the batch.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # (CPU block number, GPU block number) pairs to swap in.
    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
    # (GPU block number, CPU block number) pairs to swap out.
    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
    # (source block, destination block) pairs to copy.
    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
    # Number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # Number of requests in the running queue.
    running_queue_size: int = 0
    # Hidden states from the prior step, if any.
    previous_hidden_states: Optional[HiddenStates] = None
    # Number of forward steps to run.
    num_steps: int = 1

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> "ExecuteModelRequest":
        """Build a copy of this request around a new sequence group metadata
        list."""
        # The block lists are shallow-copied so the clone gets list objects
        # of its own; the remaining fields are passed through as they are.
        swap_in = self.blocks_to_swap_in.copy()
        swap_out = self.blocks_to_swap_out.copy()
        to_copy = self.blocks_to_copy.copy()
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=swap_in,
            blocks_to_swap_out=swap_out,
            blocks_to_copy=to_copy,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
        )
|