diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index d4d0cfa4..dd0b67df 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -39,7 +39,8 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
-                                                truncate_tool_call_ids)
+                                                truncate_tool_call_ids,
+                                                validate_request_params)
 
 logger = init_logger(__name__)
 
@@ -159,6 +160,7 @@ class OpenAIServingChat(OpenAIServing):
             # for more info: see comment in `maybe_serialize_tool_calls`
             maybe_serialize_tool_calls(request)
             truncate_tool_call_ids(request)
+            validate_request_params(request)
 
         if (request.tool_choice == "auto"
                 and not (self.enable_auto_tools and tool_parser is not None)
diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py
index c12388d9..7aac29a6 100644
--- a/vllm/transformers_utils/tokenizers/__init__.py
+++ b/vllm/transformers_utils/tokenizers/__init__.py
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
-                      truncate_tool_call_ids)
+                      truncate_tool_call_ids, validate_request_params)
 
 __all__ = [
-    "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids"
+    "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids",
+    "validate_request_params"
 ]
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index d893431f..58a114fa 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -98,6 +98,13 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"):
             request.messages[i]["tool_call_id"] = tool_call_id
 
 
+def validate_request_params(request: "ChatCompletionRequest"):
+    if (request.skip_special_tokens is not None
+            and not request.skip_special_tokens):
+        raise ValueError("skip_special_tokens=False is not supported "
+                         "for Mistral tokenizers.")
+
+
 def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
     repo_cache = os.path.join(
         huggingface_hub.constants.HF_HUB_CACHE,
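
For illustration, here is a minimal sketch (not part of the patch) of how the new check behaves. It assumes a build of vLLM with this patch applied; the SimpleNamespace merely stands in for a ChatCompletionRequest, which works because validate_request_params only reads the skip_special_tokens attribute.

    # Sketch of validate_request_params behavior; SimpleNamespace is a
    # stand-in for ChatCompletionRequest (only `skip_special_tokens` is read).
    from types import SimpleNamespace

    from vllm.transformers_utils.tokenizers import validate_request_params

    # Unset (None) or True passes through silently.
    validate_request_params(SimpleNamespace(skip_special_tokens=None))
    validate_request_params(SimpleNamespace(skip_special_tokens=True))

    # False is rejected with the ValueError introduced by the patch.
    try:
        validate_request_params(SimpleNamespace(skip_special_tokens=False))
    except ValueError as err:
        # "skip_special_tokens=False is not supported for Mistral tokenizers."
        print(err)

Because the check is called from OpenAIServingChat before sampling begins, such a request fails fast with a clear error instead of producing output the Mistral tokenizer cannot honor.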