[V1] Add all_token_ids attribute to Request (#10135)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
073a472728
commit
42b4f46b71
@ -246,7 +246,7 @@ class Scheduler:
|
|||||||
# NOTE(woosuk): Currently, we assume that each request
|
# NOTE(woosuk): Currently, we assume that each request
|
||||||
# generates at most one token at each step.
|
# generates at most one token at each step.
|
||||||
token_id = sampled_token_ids[req_index]
|
token_id = sampled_token_ids[req_index]
|
||||||
request.output_token_ids.append(token_id)
|
request.append_output_token_ids(token_id)
|
||||||
sampled.append((request, 1))
|
sampled.append((request, 1))
|
||||||
# TODO: Update the KV cache manager for prefix caching.
|
# TODO: Update the KV cache manager for prefix caching.
|
||||||
|
|
||||||
|
@ -324,7 +324,7 @@ class LLMEngine:
|
|||||||
)
|
)
|
||||||
for req, num_tokens in sampled:
|
for req, num_tokens in sampled:
|
||||||
inputs.req_ids.append(req.request_id)
|
inputs.req_ids.append(req.request_id)
|
||||||
if len(req.output_token_ids) == num_tokens:
|
if req.num_output_tokens == num_tokens:
|
||||||
# The request is first detokenized.
|
# The request is first detokenized.
|
||||||
inputs.prompt_token_ids.append(req.prompt_token_ids)
|
inputs.prompt_token_ids.append(req.prompt_token_ids)
|
||||||
else:
|
else:
|
||||||
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, List, Optional, Union
|
|||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import RequestMetrics
|
from vllm.sequence import RequestMetrics
|
||||||
|
from vllm.v1.utils import ConstantList
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.inputs import DecoderOnlyInputs
|
from vllm.inputs import DecoderOnlyInputs
|
||||||
@ -40,17 +41,39 @@ class Request:
|
|||||||
self.prompt = inputs.get("prompt")
|
self.prompt = inputs.get("prompt")
|
||||||
self.prompt_token_ids = inputs["prompt_token_ids"]
|
self.prompt_token_ids = inputs["prompt_token_ids"]
|
||||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
self.num_prompt_tokens = len(self.prompt_token_ids)
|
||||||
self.output_token_ids: List[int] = []
|
self._output_token_ids: List[int] = []
|
||||||
|
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
|
||||||
self.output_text = ""
|
self.output_text = ""
|
||||||
self.num_computed_tokens = 0
|
self.num_computed_tokens = 0
|
||||||
|
|
||||||
|
@property
def output_token_ids(self) -> ConstantList[int]:
    """Return the generated token ids wrapped in a ``ConstantList``.

    The wrapper rejects direct appends: new tokens must be added through
    ``append_output_token_ids`` so that the combined prompt+output token
    list is updated at the same time.
    """
    guarded_outputs = ConstantList(self._output_token_ids)
    return guarded_outputs
|
||||||
|
|
||||||
|
@property
def all_token_ids(self) -> ConstantList[int]:
    """Return the prompt+output token ids wrapped in a ``ConstantList``.

    The wrapper rejects direct appends: new tokens must be added through
    ``append_output_token_ids`` so that the output-only token list is
    updated at the same time.
    """
    guarded_tokens = ConstantList(self._all_token_ids)
    return guarded_tokens
|
||||||
|
|
||||||
|
def append_output_token_ids(
    self,
    token_ids: Union[int, List[int]],
) -> None:
    """Record newly generated token id(s) on the request.

    Args:
        token_ids: A single token id, or a list of token ids, to add.

    Both the output-only list and the combined prompt+output list are
    extended with the same ids so the two never drift apart.
    """
    # A lone int is promoted to a one-element list so both cases share a path.
    new_ids = [token_ids] if isinstance(token_ids, int) else token_ids
    # Order matters only for parity with the original: outputs first, then all.
    for token_store in (self._output_token_ids, self._all_token_ids):
        token_store.extend(new_ids)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_tokens(self) -> int:
|
def num_tokens(self) -> int:
|
||||||
return self.num_prompt_tokens + len(self.output_token_ids)
|
return len(self._all_token_ids)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_output_tokens(self) -> int:
|
def num_output_tokens(self) -> int:
|
||||||
return len(self.output_token_ids)
|
return len(self._output_token_ids)
|
||||||
|
|
||||||
def is_finished(self) -> bool:
|
def is_finished(self) -> bool:
|
||||||
return RequestStatus.is_finished(self.status)
|
return RequestStatus.is_finished(self.status)
|
||||||
|
64
vllm/v1/utils.py
Normal file
64
vllm/v1/utils.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from typing import Generic, List, TypeVar, overload
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class ConstantList(Generic[T]):
    """Read-only facade over an existing list.

    The wrapper keeps a reference to the list it is given (no copy), so
    reads always reflect the current contents of that list, while every
    list-style mutating call made through the wrapper raises ``Exception``.
    The owner of the underlying list can still mutate it directly.
    """

    def __init__(self, x: List[T]) -> None:
        # Hold a reference rather than a copy: construction stays O(1) and
        # the wrapper tracks later changes made by the list's owner.
        self._x = x

    def append(self, item):
        raise Exception("Cannot append to a constant list")

    def extend(self, item):
        raise Exception("Cannot extend a constant list")

    def insert(self, *args, **kwargs):
        # Accept any arguments so a list-style ``insert(index, item)`` call
        # reports the intended error instead of an arity TypeError.
        raise Exception("Cannot insert into a constant list")

    def pop(self, *args, **kwargs):
        # Accept any arguments so both ``pop()`` and ``pop(index)`` report
        # the intended error instead of an arity TypeError.
        raise Exception("Cannot pop from a constant list")

    def remove(self, item):
        raise Exception("Cannot remove from a constant list")

    def clear(self):
        raise Exception("Cannot clear a constant list")

    def index(self, item, *bounds):
        """Return the first position of ``item``.

        Optional ``start``/``stop`` positional bounds are forwarded to
        ``list.index``. Raises ``ValueError`` if the item is not found.
        """
        return self._x.index(item, *bounds)

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> List[T]:
        ...

    def __getitem__(self, item):
        # Slicing returns a plain list produced by the underlying list, i.e.
        # a new list object that is independent of the wrapped one.
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T) -> None:
        ...

    @overload
    def __setitem__(self, s: slice, value: List[T], /) -> None:
        ...

    def __setitem__(self, item, value):
        raise Exception("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise Exception("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._x!r})"
|
Loading…
x
Reference in New Issue
Block a user