[Frontend] Generate valid tool call IDs when using tokenizer-mode=mistral (#12332)
This commit is contained in:
parent 985b4a2b19
commit 314cfade02
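In short: with --tokenizer-mode mistral, tool call IDs must pass mistral-common's validator, which only accepts 9-character alphanumeric IDs. This commit makes the named tool_choice path emit such IDs (via MistralToolCall) and adds truncate_tool_call_ids to normalize over-long IDs arriving in follow-up requests. A minimal sketch of the constraint being enforced (illustrative only, not vLLM code; the example IDs are made up):

import re

# mistral-common only accepts 9-character alphanumeric tool call IDs
# (see the validator link quoted in the mistral_tool_parser.py hunk below).
VALID_MISTRAL_TOOL_CALL_ID = re.compile(r"^[a-zA-Z0-9]{9}$")

assert VALID_MISTRAL_TOOL_CALL_ID.match("Ab3xYz789")  # OK: exactly 9 alphanumerics
assert not VALID_MISTRAL_TOOL_CALL_ID.match("chatcmpl-tool-1234abcd")  # too long, non-alphanumeric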
0   tests/mistral_tool_use/__init__.py  Normal file
40  tests/mistral_tool_use/conftest.py  Normal file
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import pytest_asyncio
+from huggingface_hub import snapshot_download
+
+from tests.utils import RemoteOpenAIServer
+from vllm.platforms import current_platform
+
+from .utils import ARGS, CONFIGS, ServerConfig
+
+
+# for each server config, download the model and return the config
+@pytest.fixture(scope="session", params=CONFIGS.keys())
+def server_config(request):
+    config = CONFIGS[request.param]
+
+    if current_platform.is_rocm() and not config.get("supports_rocm", True):
+        pytest.skip("The {} model can't be tested on the ROCm platform".format(
+            config["model"]))
+
+    # download model and tokenizer using transformers
+    snapshot_download(config["model"])
+    yield CONFIGS[request.param]
+
+
+# run this for each server config
+@pytest.fixture(scope="session")
+def server(request, server_config: ServerConfig):
+    model = server_config["model"]
+    args_for_model = server_config["arguments"]
+    with RemoteOpenAIServer(model, ARGS + args_for_model,
+                            max_wait_seconds=480) as server:
+        yield server
+
+
+@pytest_asyncio.fixture
+async def client(server: RemoteOpenAIServer):
+    async with server.get_async_client() as async_client:
+        yield async_client
29  tests/mistral_tool_use/test_mistral_tool_calls.py  Normal file
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import openai
+import pytest
+
+from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL
+
+
+# test: a tool_choice with mistral-tokenizer results in an ID of length 9
+@pytest.mark.asyncio
+async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
+    models = await client.models.list()
+    model_name: str = models.data[0].id
+    chat_completion = await client.chat.completions.create(
+        messages=MESSAGES_ASKING_FOR_TOOLS,
+        temperature=0,
+        max_completion_tokens=100,
+        model=model_name,
+        tools=[WEATHER_TOOL],
+        tool_choice=WEATHER_TOOL,
+        logprobs=False)
+
+    choice = chat_completion.choices[0]
+
+    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
+    assert choice.message.role == "assistant"
+    assert choice.message.tool_calls is None \
+        or len(choice.message.tool_calls) == 1
+    assert len(choice.message.tool_calls[0].id) == 9  # length of 9 for mistral
33  tests/mistral_tool_use/utils.py  Normal file
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Dict, List, Optional
+
+from typing_extensions import TypedDict
+
+
+class ServerConfig(TypedDict, total=False):
+    model: str
+    arguments: List[str]
+    system_prompt: Optional[str]
+    supports_parallel: Optional[bool]
+    supports_rocm: Optional[bool]
+
+
+ARGS: List[str] = ["--max-model-len", "1024"]
+
+CONFIGS: Dict[str, ServerConfig] = {
+    "mistral": {
+        "model":
+        "mistralai/Mistral-7B-Instruct-v0.3",
+        "arguments": [
+            "--tokenizer-mode", "mistral",
+            "--ignore-patterns=\"consolidated.safetensors\""
+        ],
+        "system_prompt":
+        "You are a helpful assistant with access to tools. If a tool"
+        " that you have would be helpful to answer a user query, "
+        "call the tool. Otherwise, answer the user's query directly "
+        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
+        "to the user's question - just respond to it normally."
+    },
+}
vllm/entrypoints/openai/serving_chat.py
@@ -28,12 +28,15 @@ from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
+    MistralToolCall)
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
+from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
+                                                truncate_tool_call_ids)

 logger = init_logger(__name__)

@@ -150,11 +153,12 @@ class OpenAIServingChat(OpenAIServing):
             return self.create_error_response(
                 "tool_choice = \"required\" is not supported!")

-        # because of issues with pydantic we need to potentially
-        # re-serialize the tool_calls field of the request
-        # for more info: see comment in `maybe_serialize_tool_calls`
         if isinstance(tokenizer, MistralTokenizer):
+            # because of issues with pydantic we need to potentially
+            # re-serialize the tool_calls field of the request
+            # for more info: see comment in `maybe_serialize_tool_calls`
             maybe_serialize_tool_calls(request)
+            truncate_tool_call_ids(request)

         if (request.tool_choice == "auto" and
                 not (self.enable_auto_tools and tool_parser is not None)
@@ -745,11 +749,13 @@ class OpenAIServingChat(OpenAIServing):
         elif request.tool_choice and type(
                 request.tool_choice) is ChatCompletionNamedToolChoiceParam:

+            tool_call_class = MistralToolCall if isinstance(
+                tokenizer, MistralTokenizer) else ToolCall
             message = ChatMessage(
                 role=role,
                 content="",
                 tool_calls=[
-                    ToolCall(function=FunctionCall(
+                    tool_call_class(function=FunctionCall(
                         name=request.tool_choice.function.name,
                         arguments=output.text))
                 ])
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -33,7 +33,7 @@ class MistralToolCall(ToolCall):

     @staticmethod
     def generate_random_id():
-        # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
+        # Mistral Tool Call Ids must be alphanumeric with a length of 9.
         # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
         return "".join(choices(ALPHANUMERIC, k=9))

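For reference, a self-contained stand-in for generate_random_id above (illustrative sketch, not the vLLM implementation; ALPHANUMERIC is defined elsewhere in mistral_tool_parser.py and is not shown in this hunk, so ASCII letters plus digits is an assumption here):

from random import choices
from string import ascii_letters, digits

ALPHANUMERIC = ascii_letters + digits  # assumed definition (not part of this hunk)


def generate_random_id() -> str:
    # nine random alphanumeric characters, the format mistral-common validates
    return "".join(choices(ALPHANUMERIC, k=9))


assert len(generate_random_id()) == 9 and generate_random_id().isalnum()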
vllm/transformers_utils/tokenizers/__init__.py
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

-from .mistral import MistralTokenizer, maybe_serialize_tool_calls
+from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
+                      truncate_tool_call_ids)

-__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
+__all__ = [
+    "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids"
+]
vllm/transformers_utils/tokenizers/mistral.py
@@ -68,6 +68,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
         request.messages[i]["tool_calls"] = validated_tool_calls


+def truncate_tool_call_ids(request: "ChatCompletionRequest"):
+    """Truncates tool call IDs for Mistral's ID requirements."""
+    for i, message in enumerate(request.messages):
+        if message.get("role") == 'assistant':
+            tool_calls = message.get("tool_calls", [])
+            for tool_call in tool_calls:
+                if len(tool_call["id"]) > 9:
+                    logger.warning(
+                        "Truncating tool call ID: %s to %s",
+                        tool_call["id"],
+                        tool_call["id"][-9:],
+                    )
+                    tool_call["id"] = tool_call["id"][-9:]
+
+            request.messages[i]["tool_calls"] = tool_calls
+
+        elif message.get("role") in {"tool_results", "tool"}:
+            if "tool_call_id" in message:
+                tool_call_id = message["tool_call_id"]
+
+                if len(tool_call_id) > 9:
+                    logger.warning(
+                        "Truncating tool_call_id: %s to %s",
+                        tool_call_id,
+                        tool_call_id[-9:],
+                    )
+                    tool_call_id = tool_call_id[-9:]
+                request.messages[i]["tool_call_id"] = tool_call_id
+
+
 def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
     repo_cache = os.path.join(
         huggingface_hub.constants.HF_HUB_CACHE,
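A hedged usage sketch of the new helper: the real argument is a ChatCompletionRequest, but since the function only walks request.messages with dict-style access, a tiny stand-in object is enough to show the behaviour (the stand-in and the IDs below are made up):

from types import SimpleNamespace

from vllm.transformers_utils.tokenizers import truncate_tool_call_ids

# Stand-in for ChatCompletionRequest; only .messages is touched here.
request = SimpleNamespace(messages=[
    {"role": "assistant",
     "tool_calls": [{"id": "chatcmpl-tool-0123456789", "type": "function",
                     "function": {"name": "get_weather", "arguments": "{}"}}]},
    {"role": "tool", "tool_call_id": "chatcmpl-tool-0123456789",
     "content": "sunny"},
])

truncate_tool_call_ids(request)

# Both IDs are truncated to their last nine characters.
assert request.messages[0]["tool_calls"][0]["id"] == "123456789"
assert request.messages[1]["tool_call_id"] == "123456789"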