# SPDX-License-Identifier: Apache-2.0

import asyncio
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
                                   RequestStateStats)


@dataclass
class OutputProcessorOutput:
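    """Outputs of one OutputProcessor.process_outputs() call."""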

    request_outputs: List[RequestOutput]
    reqs_to_abort: List[str]


class RequestState:
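    """Per-request detokenization, logprobs, and stats state."""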

    def __init__(
        self,
        request_id: str,
        lora_name: Optional[str],
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: List[int],
        logprobs_processor: LogprobsProcessor,
        detokenizer: IncrementalDetokenizer,
        arrival_time: float,
        queue: Optional[asyncio.Queue[RequestOutput]],
        log_stats: bool,
    ):
        self.request_id = request_id
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_len = len(prompt_token_ids)
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.is_prefilling = True
        self.queue = queue

        self.stats = RequestStateStats(
            arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        queue: Optional[asyncio.Queue[RequestOutput]],
        log_stats: bool,
    ) -> "RequestState":
        return cls(
            request_id=request.request_id,
            lora_name=(request.lora_request.name
                       if request.lora_request is not None else None),
            output_kind=request.sampling_params.output_kind,
            prompt=request.prompt,
            prompt_token_ids=request.prompt_token_ids,
            logprobs_processor=LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            detokenizer=IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )


class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: BaseTokenizerGroup,
        log_stats: bool,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.request_states: Dict[str, RequestState] = {}
        self.lora_states = LoRARequestStates()

    def is_request_active(self, request_id: str) -> bool:
        return request_id in self.request_states

    def get_num_unfinished_requests(self):
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        return len(self.request_states) > 0

    def abort_requests(
        self,
        request_ids: List[str],
    ) -> None:
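        """Stop tracking the given request ids; unknown ids are ignored."""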
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)

    def add_request(
        self,
        request: EngineCoreRequest,
        queue: Optional[asyncio.Queue[RequestOutput]] = None,
    ) -> None:
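        """Start tracking a new request so its outputs can be processed."""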
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
            request=request,
            queue=queue,
            log_stats=self.log_stats)
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)

    def process_outputs(
        self,
        engine_core_outputs: List[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        ****************** NOTE FOR DEVELOPERS ******************

        VLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.

        **********************************************************
        """

        request_outputs: List[RequestOutput] = []
        reqs_to_abort: List[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            new_token_ids = engine_core_output.new_token_ids
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason

            # TODO(andy): prompt logprobs + chunked prefill can
            # result in engine core returning an output for a
            # partial prefill (in order to send back partial
            # prompt logprobs.) This breaks the invariant that
            # process_outputs is only operating on engine core
            # outputs associated with non-partial completions.
            # Currently this is handled by having `is_prefilling`
            # check for new decoded tokens, indicating that
            # the completion is not partial.
            #
            # Follow up will aggregate partial prompt logprobs
            # in the EngineCore.
            req_state.is_prefilling = not new_token_ids

            # 2) Detokenize the token ids into text and check for stop
            # strings.
            stop_string = req_state.detokenizer.update(new_token_ids)
            if stop_string and finish_reason != FinishReason.STOP:
                finish_reason = FinishReason.STOP
                stop_reason = stop_string

            # 3) Compute sample and prompt logprobs for request,
            # if required.
            req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := self._make_request_output(
                    req_state, new_token_ids, finish_reason, stop_reason):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put_nowait(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

                # Free completed requests.
                if request_output.finished:
                    self.request_states.pop(req_id)
                    if not engine_core_output.finished:
                        # If req not finished in EngineCore, but Detokenizer
                        # detected stop string, abort needed in EngineCore.
                        reqs_to_abort.append(req_id)

                    # Track per-request stats
                    self._update_stats_from_finished(req_state, request_output,
                                                     finish_reason,
                                                     iteration_stats)

        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
                                  iteration_stats: Optional[IterationStats]):
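        """Update iteration stats (no-op when iteration_stats is None)."""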
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(engine_core_output,
                                           engine_core_timestamp,
                                           req_state.is_prefilling,
                                           req_state.prompt_len,
                                           req_state.stats, lora_stats)

    def _update_stats_from_finished(self, req_state: RequestState,
                                    request_output: RequestOutput,
                                    finish_reason: Optional[FinishReason],
                                    iteration_stats: Optional[IterationStats]):
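        """Record final request stats (no-op when iteration_stats is None)."""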
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(finish_reason,
                                                     request_output,
                                                     req_state.stats)
        self.lora_states.finish_request(req_state)

    @staticmethod
    def _make_request_output(
        request_state: RequestState,
        new_token_ids: List[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
    ) -> Optional[RequestOutput]:
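        """Build the RequestOutput for this step, or return None when no
        output should be emitted (e.g. FINAL_ONLY before finishing)."""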

        finished = finish_reason is not None
        output_kind = request_state.output_kind
        # In follow up, we will switch to invariant where EngineCore
        # does not stream partial prefills.
        if not finished and (request_state.is_prefilling
                             or output_kind == RequestOutputKind.FINAL_ONLY):
            # Only the final output is required in FINAL_ONLY mode.
            return None

        detokenizer = request_state.detokenizer
        logprobs_processor = request_state.logprobs_processor

        delta = output_kind == RequestOutputKind.DELTA
        logprobs = logprobs_processor.logprobs
        if delta:
            if logprobs:
                logprobs = logprobs[-len(new_token_ids):]
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = logprobs_processor.prompt_logprobs

        request_output = RequestOutput.new(
            request_id=request_state.request_id,
            prompt=request_state.prompt,
            prompt_token_ids=request_state.prompt_token_ids,
            text=detokenizer.get_next_output_text(finished, delta),
            token_ids=new_token_ids if delta else detokenizer.output_token_ids,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            cumulative_logprob=logprobs_processor.cumulative_logprob,
            finished=finished,
        )
        if finished:
            completion_output = request_output.outputs[0]
            completion_output.finish_reason = str(finish_reason)
            completion_output.stop_reason = stop_reason

        return request_output