[V1] Add all_token_ids attribute to Request (#10135)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2024-11-07 17:08:24 -08:00 committed by GitHub
parent 073a472728
commit 42b4f46b71
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 92 additions and 5 deletions

View File

@ -246,7 +246,7 @@ class Scheduler:
# NOTE(woosuk): Currently, we assume that each request
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
request.output_token_ids.append(token_id)
request.append_output_token_ids(token_id)
sampled.append((request, 1))
# TODO: Update the KV cache manager for prefix caching.

View File

@ -324,7 +324,7 @@ class LLMEngine:
)
for req, num_tokens in sampled:
inputs.req_ids.append(req.request_id)
if len(req.output_token_ids) == num_tokens:
if req.num_output_tokens == num_tokens:
# The request is first detokenized.
inputs.prompt_token_ids.append(req.prompt_token_ids)
else:

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, List, Optional, Union
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.inputs import DecoderOnlyInputs
@ -40,17 +41,39 @@ class Request:
self.prompt = inputs.get("prompt")
self.prompt_token_ids = inputs["prompt_token_ids"]
self.num_prompt_tokens = len(self.prompt_token_ids)
self.output_token_ids: List[int] = []
self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.output_text = ""
self.num_computed_tokens = 0
@property
def output_token_ids(self) -> ConstantList[int]:
    """Read-only view of the tokens generated so far.

    Returned as a ConstantList so callers cannot append directly;
    appends must go through append_output_token_ids() so that
    _all_token_ids is updated at the same time.
    """
    view = ConstantList(self._output_token_ids)
    return view
@property
def all_token_ids(self) -> ConstantList[int]:
    """Read-only view of prompt tokens followed by output tokens.

    Returned as a ConstantList so callers cannot append directly;
    appends must go through append_output_token_ids() so that
    _output_token_ids is updated at the same time.
    """
    view = ConstantList(self._all_token_ids)
    return view
def append_output_token_ids(
self,
token_ids: Union[int, List[int]],
) -> None:
if isinstance(token_ids, int):
token_ids = [token_ids]
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
return len(self._all_token_ids)
@property
def num_output_tokens(self) -> int:
return len(self.output_token_ids)
return len(self._output_token_ids)
def is_finished(self) -> bool:
    """Return True if this request has reached a terminal status.

    Delegates the terminal-state check to RequestStatus.is_finished;
    this method only forwards self.status.
    """
    return RequestStatus.is_finished(self.status)

64
vllm/v1/utils.py Normal file
View File

@ -0,0 +1,64 @@
from typing import Generic, List, TypeVar, overload
T = TypeVar("T")


class ConstantList(Generic[T]):
    """A read-only view over an existing list.

    Wraps a list WITHOUT copying it: reads delegate to the wrapped
    list (and therefore reflect later updates made by the owner via
    its private reference), while every mutating operation raises.
    Used to expose internal token-id lists to callers that must not
    mutate them directly.
    """

    def __init__(self, x: List[T]) -> None:
        # Intentionally no copy — the view must stay in sync with the
        # owner's list.
        self._x = x

    # --- Mutating operations: always rejected. ---
    # Each takes *args/**kwargs so that ANY call shape raises the
    # intended Exception. Fixes the original pop(self, item) /
    # insert(self, item) signatures, where a zero-arg pop() or a
    # two-arg insert(i, x) raised a confusing TypeError about the
    # argument count instead of the "constant list" message.

    def append(self, *args, **kwargs) -> None:
        raise Exception("Cannot append to a constant list")

    def extend(self, *args, **kwargs) -> None:
        raise Exception("Cannot extend a constant list")

    def insert(self, *args, **kwargs) -> None:
        raise Exception("Cannot insert into a constant list")

    def pop(self, *args, **kwargs) -> None:
        raise Exception("Cannot pop from a constant list")

    def remove(self, *args, **kwargs) -> None:
        raise Exception("Cannot remove from a constant list")

    def clear(self) -> None:
        raise Exception("Cannot clear a constant list")

    # --- Read-only operations: delegate to the wrapped list. ---

    def index(self, item: T, *args) -> int:
        # Pass through the optional start/stop bounds of list.index.
        return self._x.index(item, *args)

    def count(self, item: T) -> int:
        return self._x.count(item)

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> List[T]:
        ...

    def __getitem__(self, item):
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: List[T], /):
        ...

    def __setitem__(self, item, value):
        raise Exception("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise Exception("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item) -> bool:
        return item in self._x

    def __len__(self) -> int:
        return len(self._x)

    def __repr__(self) -> str:
        return f"ConstantList({self._x!r})"