# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final,
                    List, Optional)
from typing import Sequence as GenericSequence
from typing import Union

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                         ConversationMessage)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
    DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
    RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
                                                       ReasoningParserManager)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
    MistralToolCall)
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
                                                truncate_tool_call_ids)

logger = init_logger(__name__)

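
# OpenAIServingChat implements the OpenAI-compatible Chat Completions
# endpoint on top of a vLLM EngineClient: it preprocesses chat-template
# prompts, optionally runs tool-call and reasoning-content parsers, and
# returns either a streaming (SSE) or a full response.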
class OpenAIServingChat(OpenAIServing):

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        response_role: str,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        return_tokens_as_token_ids: bool = False,
        enable_reasoning: bool = False,
        reasoning_parser: Optional[str] = None,
        enable_auto_tools: bool = False,
        tool_parser: Optional[str] = None,
        enable_prompt_tokens_details: bool = False,
    ) -> None:
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

        # set up tool use
        self.enable_auto_tools: bool = enable_auto_tools
        if self.enable_auto_tools:
            logger.info(
                "\"auto\" tool choice has been enabled; please note that "
                "while the parallel_tool_calls client option is accepted "
                "for compatibility reasons, it will be ignored.")

        self.enable_reasoning: bool = enable_reasoning
        self.reasoning_parser: Optional[Callable[[AnyTokenizer],
                                                 ReasoningParser]] = None
        if self.enable_reasoning:
            try:
                self.reasoning_parser = (
                    ReasoningParserManager.get_reasoning_parser(
                        reasoning_parser))
            except Exception as e:
                raise TypeError("Error: --enable-reasoning requires "
                                f"reasoning_parser:'{reasoning_parser}' "
                                "which has not been registered") from e

        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
        if self.enable_auto_tools:
            try:
                if (tool_parser == "pythonic" and
                        model_config.model.startswith("meta-llama/Llama-3.2")):
                    logger.warning(
                        "Llama3.2 models may struggle to emit valid pythonic"
                        " tool calls")
                self.tool_parser = ToolParserManager.get_tool_parser(
                    tool_parser)
            except Exception as e:
                raise TypeError("Error: --enable-auto-tool-choice requires "
                                f"tool_parser:'{tool_parser}' which has not "
                                "been registered") from e

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        diff_sampling_param = self.model_config.get_diff_sampling_param()
        if diff_sampling_param:
            logger.info("Overwriting default chat sampling param with: %s",
                        diff_sampling_param)

    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """
        Chat Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        Chat Completion API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_name = self.models.model_name(lora_request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            tool_parser = self.tool_parser

            # validation for OpenAI tools
            # tool_choice = "required" is not supported
            if request.tool_choice == "required":
                return self.create_error_response(
                    "tool_choice = \"required\" is not supported!")

            if isinstance(tokenizer, MistralTokenizer):
                # because of issues with pydantic we need to potentially
                # re-serialize the tool_calls field of the request
                # for more info: see comment in `maybe_serialize_tool_calls`
                maybe_serialize_tool_calls(request)
                truncate_tool_call_ids(request)

            if (request.tool_choice == "auto" and
                    not (self.enable_auto_tools and tool_parser is not None)
                    and not isinstance(tokenizer, MistralTokenizer)):
                # for hf tokenizers, "auto" tools requires
                # --enable-auto-tool-choice and --tool-call-parser
                return self.create_error_response(
                    "\"auto\" tool choice requires "
                    "--enable-auto-tool-choice and --tool-call-parser to be set"
                )

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            (
                conversation,
                request_prompts,
                engine_prompts,
            ) = await self._preprocess_chat(
                request,
                tokenizer,
                request.messages,
                chat_template=request.chat_template or self.chat_template,
                chat_template_content_format=self.chat_template_content_format,
                add_generation_prompt=request.add_generation_prompt,
                continue_final_message=request.continue_final_message,
                tool_dicts=tool_dicts,
                documents=request.documents,
                chat_template_kwargs=request.chat_template_kwargs,
                tool_parser=tool_parser,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        request_id = "chatcmpl-" \
            f"{self._base_request_id(raw_request, request.request_id)}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
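                # Default generation budget: the tokens remaining in the
                # model's context window after the prompt.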
                default_max_tokens = self.max_model_len - len(
                    engine_prompt["prompt_token_ids"])
                # Build default sampling params
                default_sampling_params = (
                    self.model_config.get_diff_sampling_param())
                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens, default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
                        default_max_tokens,
                        self.model_config.logits_processor_pattern,
                        default_sampling_params)

                self._log_inputs(request_id,
                                 request_prompts[i],
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        prompt_adapter_request=prompt_adapter_request,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        assert len(generators) == 1
        result_generator, = generators

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, model_name,
                conversation, tokenizer, request_metadata)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, model_name,
                conversation, tokenizer, request_metadata)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        return request.messages[-1]["role"]

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
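        # The response is streamed using the OpenAI-style server-sent-events
        # framing: every chunk built below is serialized as a
        # ChatCompletionStreamResponse and yielded as a "data: <json>\n\n"
        # event, and the stream ends with a final "data: [DONE]\n\n" sentinel.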
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        num_prompt_tokens = 0
        num_cached_tokens = None

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
            and self._should_stream_with_auto_tool_parsing(request))

        should_stream_with_reasoning_parsing = (
            self._should_stream_with_reasoning_parsing(request))

        all_previous_token_ids: Optional[List[List[int]]]

        # Only one of these will be used, thus previous_texts and
        # all_previous_token_ids will not be used twice in the same iteration.
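        # previous_texts / all_previous_token_ids accumulate the full decoded
        # text and token ids seen so far per choice, so the incremental tool
        # and reasoning parsers can diff the previous state against the
        # current one on every delta.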
        if tool_choice_auto or should_stream_with_reasoning_parsing:
            # These are only required in "auto" tool choice case
            previous_texts = [""] * num_choices
            all_previous_token_ids = [[]] * num_choices
        else:
            previous_texts, all_previous_token_ids = None, None

        try:
            # The reasoning parser is guaranteed to be set whenever
            # should_stream_with_reasoning_parsing is True; the explicit
            # None check below only exists to satisfy the type checker
            # (pre-commit).
            if should_stream_with_reasoning_parsing and \
                    self.reasoning_parser is not None:
                reasoning_parser = self.reasoning_parser(tokenizer)
        except RuntimeError as e:
            logger.exception("Error in reasoning parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

        # Prepare the tool parser if it's needed
        try:
            if tool_choice_auto and self.tool_parser:
                tool_parsers: List[Optional[ToolParser]] = [
                    self.tool_parser(tokenizer)
                ] * num_choices
            else:
                tool_parsers = [None] * num_choices
        except Exception as e:
            logger.exception("Error in tool parser creation.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage
            include_continuous_usage = include_usage and \
                stream_options.continuous_usage_stats
        else:
            include_usage, include_continuous_usage = False, False

        try:
            async for res in result_generator:
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)
                    if res.encoder_prompt_token_ids is not None:
                        num_prompt_tokens += len(res.encoder_prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # if continuous usage stats are requested, add it
                        if include_continuous_usage:
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=0,
                                total_tokens=num_prompt_tokens)

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content: Union[str, List[Dict[str, str]]] = ""
                        if conversation and "content" in conversation[
                                -1] and conversation[-1].get("role") == role:
                            last_msg_content = conversation[-1]["content"] or ""

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if include_continuous_usage:
                                    chunk.usage = UsageInfo(
                                        prompt_tokens=num_prompt_tokens,
                                        completion_tokens=0,
                                        total_tokens=num_prompt_tokens)

                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index
                    tool_parser = tool_parsers[i]

                    if finish_reason_sent[i]:
                        continue

                    if request.logprobs and request.top_logprobs is not None:
                        assert output.logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text

                    if not delta_text and not output.token_ids and \
                            not previous_num_tokens[i]:
                        # Chunked prefill case, don't return empty chunks
                        continue

                    delta_message: Optional[DeltaMessage]

                    # handle streaming deltas for tools with named tool_choice
                    if tool_choice_function_name:
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(function=DeltaFunctionCall(
                                name=tool_choice_function_name,
                                arguments=delta_text),
                                          index=i)
                        ])

                    # handle streaming deltas for tools with "auto" tool choice
                    elif tool_choice_auto:
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        assert tool_parser is not None
                        # TODO optimize manipulation of these lists
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)

                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_text,
                                current_text=current_text,
                                delta_text=delta_text,
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
                                delta_token_ids=output.token_ids,
                                request=request))

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids

                    # reasoning_content cannot be enabled with tool_choice.
                    # If it is, the tool_choice will be used instead.
                    elif self.enable_reasoning:
                        # handle reasoning_content delta
                        assert reasoning_parser is not None
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)

                        delta_message = (reasoning_parser.
                                         extract_reasoning_content_streaming(
                                             previous_text,
                                             current_text,
                                             delta_text,
                                             previous_token_ids,
                                             current_token_ids,
                                             output.token_ids,
                                         ))

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids

                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    # set the previous values for the next iteration
                    previous_num_tokens[i] += len(output.token_ids)

                    # If the message delta is None (e.g. it was a "control
                    # token" for tool calls, or the parser otherwise wasn't
                    # ready to emit a token), move on to the next token
                    # without streaming a chunk.
                    if delta_message is None:
                        continue

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)

                    # if the model is finished generating
                    else:
                        # check to make sure we haven't "forgotten" to stream
                        # any tokens that were generated but previously
                        # matched by partial json parsing
                        # only happens if we are NOT using guided decoding
                        auto_tools_called = False
                        if tool_parser:
                            auto_tools_called = len(
                                tool_parser.prev_tool_call_arr) > 0
                            index = len(tool_parser.prev_tool_call_arr
                                        ) - 1 if auto_tools_called else 0
                        else:
                            index = 0

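                        # Recover any argument text the parser buffered but
                        # never streamed: diff the "autocompleted" JSON from
                        # partial parsing against what was actually emitted
                        # and send the remainder. Illustrative (hypothetical)
                        # example: expected '{"city": "Paris"}' vs streamed
                        # '{"city": "Pa' leaves 'ris"}' to flush.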
                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
                            latest_delta_len = 0
                            if ((isinstance(
                                    delta_message.tool_calls[0].function,
                                    DeltaFunctionCall)) and isinstance(
                                        delta_message.tool_calls[0].function.
                                        arguments, str)):
                                latest_delta_len = len(
                                    delta_message.tool_calls[0].function.
                                    arguments)

                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
                                    "arguments", {}),
                                ensure_ascii=False)

                            # get what we've streamed so far for arguments
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]
                            if (latest_delta_len > 0):
                                actual_call = actual_call[:-latest_delta_len]

                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)
                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])

                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason
                            if not auto_tools_called else "tool_calls",
                            stop_reason=output.stop_reason)

                        finish_reason_sent[i] = True

                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)

                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=num_prompt_tokens + completion_tokens,
                        )

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
            if include_usage:
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                                        completion_tokens=completion_tokens,
                                        total_tokens=num_prompt_tokens +
                                        completion_tokens)
                if self.enable_prompt_tokens_details and num_cached_tokens:
                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
                        cached_tokens=num_cached_tokens)

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            num_completion_tokens = sum(previous_num_tokens)
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
                total_tokens=num_prompt_tokens + num_completion_tokens)

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            logger.exception("Error in chat completion stream generator.")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        model_name: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:

        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
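        # Build one response choice per sampled output; how the message is
        # constructed depends on whether reasoning parsing, a named tool
        # choice, or "auto" tool choice applies to this request.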
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            should_stream_with_reasoning_parsing = (
                self._should_stream_with_reasoning_parsing(request))

            # In the OpenAI API the finish_reason is "tool_calls"
            # if the tool choice is auto and the model produced a tool
            # call. The same is not true for named function calls
            auto_tools_called = False

            if should_stream_with_reasoning_parsing and \
                    self.reasoning_parser is not None:
                try:
                    reasoning_parser = self.reasoning_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in reasoning parser creation.")
                    return self.create_error_response(str(e))

                reasoning_content, content = (
                    reasoning_parser.extract_reasoning_content(
                        output.text, request=request))

                if reasoning_content:
                    message = ChatMessage(role=role,
                                          content=content,
                                          reasoning_content=reasoning_content)
                else:
                    message = ChatMessage(role=role, content=output.text)

            # if auto tools are not enabled, and a named tool choice using
            # outlines is not being used
            elif (not self.enable_auto_tools
                  or not self.tool_parser) and not isinstance(
                      request.tool_choice, ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(role=role, content=output.text)

            # if the request uses tools and specified a tool choice
            elif request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:

                tool_call_class = MistralToolCall if isinstance(
                    tokenizer, MistralTokenizer) else ToolCall
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        tool_call_class(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":

                message = ChatMessage(role=role, content=output.text)

            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:

                try:
                    tool_parser = self.tool_parser(tokenizer)
                except RuntimeError as e:
                    logger.exception("Error in tool parser creation.")
                    return self.create_error_response(str(e))

                tool_call_info = tool_parser.extract_tool_calls(
                    output.text, request=request)
                # In the OpenAI API the finish_reason is "tool_calls"
                # if the tool choice is auto and the model produced a tool
                # call. The same is not true for named function calls
                auto_tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)

                else:
                    # FOR NOW return it as a plain chat message; detecting the
                    # proper type and converting it will have to happen later.
                    message = ChatMessage(role=role, content=output.text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls" if auto_tools_called else
                output.finish_reason if output.finish_reason else "stop",
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            last_msg_content: Union[str, List[Dict[str, str]]] = ""
            if conversation and "content" in conversation[-1] and conversation[
                    -1].get("role") == role:
                last_msg_content = conversation[-1]["content"] or ""
            if isinstance(last_msg_content, list):
                last_msg_content = "\n".join(msg['text']
                                             for msg in last_msg_content)

            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                          completion_tokens=num_generated_tokens,
                          total_tokens=num_prompt_tokens +
                          num_generated_tokens)
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens)

        request_metadata.final_usage_info = usage

        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response

    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
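        # The logprob value is clamped to a finite floor (-9999.0) below so
        # that values such as -inf stay representable in the JSON response.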
        return [
            ChatCompletionLogProb(token=(token := self._get_decoded_token(
                p[1],
                p[0],
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)),
                                  logprob=max(p[1].logprob, -9999.0),
                                  bytes=list(
                                      token.encode("utf-8", errors="replace")))
            for i, p in enumerate(logprobs.items())
            if top_logprobs and i < top_logprobs
        ]

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs."""
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
            else:
                step_token = step_top_logprobs[token_id]
                step_decoded = step_token.decoded_token

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_token,
                            token_id,
                            tokenizer,
                            self.return_tokens_as_token_ids,
                        ),
                        logprob=max(step_token.logprob, -9999.0),
                        bytes=None if step_decoded is None else list(
                            step_decoded.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs,
                            num_output_top_logprobs,
                            tokenizer,
                        ),
                    ))

        return ChatCompletionLogProbs(content=logprobs_content)

    def _should_stream_with_auto_tool_parsing(self,
                                              request: ChatCompletionRequest):
        """
        Utility function to check if streamed tokens should go through the tool
        call parser that was configured.

        We only want to do this IF user-provided tools are set, a tool parser
        is configured, "auto" tool choice is enabled, and the request's tool
        choice field indicates that "auto" tool choice should be used.
        """
        return (request.tools and self.tool_parser and self.enable_auto_tools
                and request.tool_choice in ['auto', None])

    def _should_stream_with_reasoning_parsing(self,
                                              request: ChatCompletionRequest):
        """
        Utility function to check if streamed tokens should go through the
        reasoning parser that was configured.

        We only want to do this IF reasoning is enabled and a reasoning
        parser is configured.
        """
        return self.enable_reasoning and self.reasoning_parser is not None

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: Optional[DeltaMessage],
        output: CompletionOutput,
    ) -> bool:
"""
|
|
|
|
Check to see if we should check for unstreamed tool arguments tokens.
|
|
|
|
This is only applicable when auto tool parsing is enabled, the delta
|
|
|
|
is a tool call with arguments.
|
|
|
|
"""

        # yapf: disable
        return bool(
            # if there is a delta message that includes tool calls which
            # include a function that has arguments
            output.finish_reason is not None
            and self.enable_auto_tools and self.tool_parser and delta_message
            and delta_message.tool_calls and delta_message.tool_calls[0]
            and delta_message.tool_calls[0].function
            and delta_message.tool_calls[0].function.arguments is not None
        )