[V1] Add all_token_ids attribute to Request (#10135)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
073a472728
commit
42b4f46b71
@ -246,7 +246,7 @@ class Scheduler:
|
||||
# NOTE(woosuk): Currently, we assume that each request
|
||||
# generates at most one token at each step.
|
||||
token_id = sampled_token_ids[req_index]
|
||||
request.output_token_ids.append(token_id)
|
||||
request.append_output_token_ids(token_id)
|
||||
sampled.append((request, 1))
|
||||
# TODO: Update the KV cache manager for prefix caching.
|
||||
|
||||
|
@ -324,7 +324,7 @@ class LLMEngine:
|
||||
)
|
||||
for req, num_tokens in sampled:
|
||||
inputs.req_ids.append(req.request_id)
|
||||
if len(req.output_token_ids) == num_tokens:
|
||||
if req.num_output_tokens == num_tokens:
|
||||
# The request is first detokenized.
|
||||
inputs.prompt_token_ids.append(req.prompt_token_ids)
|
||||
else:
|
||||
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import RequestMetrics
|
||||
from vllm.v1.utils import ConstantList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.inputs import DecoderOnlyInputs
|
||||
@ -40,17 +41,39 @@ class Request:
|
||||
self.prompt = inputs.get("prompt")
|
||||
self.prompt_token_ids = inputs["prompt_token_ids"]
|
||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
||||
self.output_token_ids: List[int] = []
|
||||
self._output_token_ids: List[int] = []
|
||||
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
|
||||
self.output_text = ""
|
||||
self.num_computed_tokens = 0
|
||||
|
||||
@property
|
||||
def output_token_ids(self) -> ConstantList[int]:
|
||||
# Prevent directly appending to the output_token_ids since
|
||||
# all_token_ids should also be updated simultaneously.
|
||||
return ConstantList(self._output_token_ids)
|
||||
|
||||
@property
|
||||
def all_token_ids(self) -> ConstantList[int]:
|
||||
# Prevent directly appending to the all_token_ids since
|
||||
# output_token_ids should also be updated simultaneously
|
||||
return ConstantList(self._all_token_ids)
|
||||
|
||||
def append_output_token_ids(
|
||||
self,
|
||||
token_ids: Union[int, List[int]],
|
||||
) -> None:
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
self._output_token_ids.extend(token_ids)
|
||||
self._all_token_ids.extend(token_ids)
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return self.num_prompt_tokens + len(self.output_token_ids)
|
||||
return len(self._all_token_ids)
|
||||
|
||||
@property
|
||||
def num_output_tokens(self) -> int:
|
||||
return len(self.output_token_ids)
|
||||
return len(self._output_token_ids)
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return RequestStatus.is_finished(self.status)
|
||||
|
64
vllm/v1/utils.py
Normal file
64
vllm/v1/utils.py
Normal file
@ -0,0 +1,64 @@
|
||||
from typing import Generic, List, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ConstantList(Generic[T]):
|
||||
|
||||
def __init__(self, x: List[T]) -> None:
|
||||
self._x = x
|
||||
|
||||
def append(self, item):
|
||||
raise Exception("Cannot append to a constant list")
|
||||
|
||||
def extend(self, item):
|
||||
raise Exception("Cannot extend a constant list")
|
||||
|
||||
def insert(self, item):
|
||||
raise Exception("Cannot insert into a constant list")
|
||||
|
||||
def pop(self, item):
|
||||
raise Exception("Cannot pop from a constant list")
|
||||
|
||||
def remove(self, item):
|
||||
raise Exception("Cannot remove from a constant list")
|
||||
|
||||
def clear(self):
|
||||
raise Exception("Cannot clear a constant list")
|
||||
|
||||
def index(self, item):
|
||||
return self._x.index(item)
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item) -> T:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, s: slice, /) -> List[T]:
|
||||
...
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._x[item]
|
||||
|
||||
@overload
|
||||
def __setitem__(self, item, value):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __setitem__(self, s: slice, value, /):
|
||||
...
|
||||
|
||||
def __setitem__(self, item, value):
|
||||
raise Exception("Cannot set item in a constant list")
|
||||
|
||||
def __delitem__(self, item):
|
||||
raise Exception("Cannot delete item from a constant list")
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._x)
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._x
|
||||
|
||||
def __len__(self):
|
||||
return len(self._x)
|
Loading…
x
Reference in New Issue
Block a user