[Frontend] Generate valid tool call IDs when using tokenizer-mode=mistral
(#12332)
This commit is contained in:
parent
985b4a2b19
commit
314cfade02
0
tests/mistral_tool_use/__init__.py
Normal file
0
tests/mistral_tool_use/__init__.py
Normal file
40
tests/mistral_tool_use/conftest.py
Normal file
40
tests/mistral_tool_use/conftest.py
Normal file
@ -0,0 +1,40 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import ARGS, CONFIGS, ServerConfig
|
||||
|
||||
|
||||
# for each server config, download the model and return the config
|
||||
@pytest.fixture(scope="session", params=CONFIGS.keys())
|
||||
def server_config(request):
|
||||
config = CONFIGS[request.param]
|
||||
|
||||
if current_platform.is_rocm() and not config.get("supports_rocm", True):
|
||||
pytest.skip("The {} model can't be tested on the ROCm platform".format(
|
||||
config["model"]))
|
||||
|
||||
# download model and tokenizer using transformers
|
||||
snapshot_download(config["model"])
|
||||
yield CONFIGS[request.param]
|
||||
|
||||
|
||||
# run this for each server config
|
||||
@pytest.fixture(scope="session")
|
||||
def server(request, server_config: ServerConfig):
|
||||
model = server_config["model"]
|
||||
args_for_model = server_config["arguments"]
|
||||
with RemoteOpenAIServer(model, ARGS + args_for_model,
|
||||
max_wait_seconds=480) as server:
|
||||
yield server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server: RemoteOpenAIServer):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
29
tests/mistral_tool_use/test_mistral_tool_calls.py
Normal file
29
tests/mistral_tool_use/test_mistral_tool_calls.py
Normal file
@ -0,0 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL
|
||||
|
||||
|
||||
# test: a tool_choice with mistral-tokenizer results in an ID of length 9
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_ASKING_FOR_TOOLS,
|
||||
temperature=0,
|
||||
max_completion_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL],
|
||||
tool_choice=WEATHER_TOOL,
|
||||
logprobs=False)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
|
||||
assert choice.finish_reason != "tool_calls" # "stop" or "length"
|
||||
assert choice.message.role == "assistant"
|
||||
assert choice.message.tool_calls is None \
|
||||
or len(choice.message.tool_calls) == 1
|
||||
assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral
|
33
tests/mistral_tool_use/utils.py
Normal file
33
tests/mistral_tool_use/utils.py
Normal file
@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class ServerConfig(TypedDict, total=False):
|
||||
model: str
|
||||
arguments: List[str]
|
||||
system_prompt: Optional[str]
|
||||
supports_parallel: Optional[bool]
|
||||
supports_rocm: Optional[bool]
|
||||
|
||||
|
||||
ARGS: List[str] = ["--max-model-len", "1024"]
|
||||
|
||||
CONFIGS: Dict[str, ServerConfig] = {
|
||||
"mistral": {
|
||||
"model":
|
||||
"mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"arguments": [
|
||||
"--tokenizer-mode", "mistral",
|
||||
"--ignore-patterns=\"consolidated.safetensors\""
|
||||
],
|
||||
"system_prompt":
|
||||
"You are a helpful assistant with access to tools. If a tool"
|
||||
" that you have would be helpful to answer a user query, "
|
||||
"call the tool. Otherwise, answer the user's query directly "
|
||||
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
|
||||
"to the user's question - just respond to it normally."
|
||||
},
|
||||
}
|
@ -28,12 +28,15 @@ from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||
MistralToolCall)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
|
||||
from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
|
||||
truncate_tool_call_ids)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -150,11 +153,12 @@ class OpenAIServingChat(OpenAIServing):
|
||||
return self.create_error_response(
|
||||
"tool_choice = \"required\" is not supported!")
|
||||
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
# for more info: see comment in `maybe_serialize_tool_calls`
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
# for more info: see comment in `maybe_serialize_tool_calls`
|
||||
maybe_serialize_tool_calls(request)
|
||||
truncate_tool_call_ids(request)
|
||||
|
||||
if (request.tool_choice == "auto" and
|
||||
not (self.enable_auto_tools and tool_parser is not None)
|
||||
@ -745,11 +749,13 @@ class OpenAIServingChat(OpenAIServing):
|
||||
elif request.tool_choice and type(
|
||||
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
|
||||
|
||||
tool_call_class = MistralToolCall if isinstance(
|
||||
tokenizer, MistralTokenizer) else ToolCall
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCall(function=FunctionCall(
|
||||
tool_call_class(function=FunctionCall(
|
||||
name=request.tool_choice.function.name,
|
||||
arguments=output.text))
|
||||
])
|
||||
|
@ -33,7 +33,7 @@ class MistralToolCall(ToolCall):
|
||||
|
||||
@staticmethod
|
||||
def generate_random_id():
|
||||
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
|
||||
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
|
||||
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
|
||||
return "".join(choices(ALPHANUMERIC, k=9))
|
||||
|
||||
|
@ -1,5 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .mistral import MistralTokenizer, maybe_serialize_tool_calls
|
||||
from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
|
||||
truncate_tool_call_ids)
|
||||
|
||||
__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
|
||||
__all__ = [
|
||||
"MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids"
|
||||
]
|
||||
|
@ -68,6 +68,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
|
||||
request.messages[i]["tool_calls"] = validated_tool_calls
|
||||
|
||||
|
||||
def truncate_tool_call_ids(request: "ChatCompletionRequest"):
|
||||
"""Truncates tool call IDs for Mistral's ID requirements."""
|
||||
for i, message in enumerate(request.messages):
|
||||
if message.get("role") == 'assistant':
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
for tool_call in tool_calls:
|
||||
if len(tool_call["id"]) > 9:
|
||||
logger.warning(
|
||||
"Truncating tool call ID: %s to %s",
|
||||
tool_call["id"],
|
||||
tool_call["id"][-9:],
|
||||
)
|
||||
tool_call["id"] = tool_call["id"][-9:]
|
||||
|
||||
request.messages[i]["tool_calls"] = tool_calls
|
||||
|
||||
elif message.get("role") in {"tool_results", "tool"}:
|
||||
if "tool_call_id" in message:
|
||||
tool_call_id = message["tool_call_id"]
|
||||
|
||||
if len(tool_call_id) > 9:
|
||||
logger.warning(
|
||||
"Truncating tool_call_id: %s to %s",
|
||||
tool_call_id,
|
||||
tool_call_id[-9:],
|
||||
)
|
||||
tool_call_id = tool_call_id[-9:]
|
||||
request.messages[i]["tool_call_id"] = tool_call_id
|
||||
|
||||
|
||||
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
|
||||
repo_cache = os.path.join(
|
||||
huggingface_hub.constants.HF_HUB_CACHE,
|
||||
|
Loading…
x
Reference in New Issue
Block a user