337 lines
10 KiB
Python
337 lines
10 KiB
Python
![]() |
# SPDX-License-Identifier: Apache-2.0
|
||
|
import json
|
||
|
import re
|
||
|
from copy import deepcopy
|
||
|
from unittest.mock import MagicMock
|
||
|
|
||
|
import pytest
|
||
|
from pydantic import TypeAdapter
|
||
|
|
||
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||
|
ChatCompletionToolsParam)
|
||
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||
|
|
||
|
EXAMPLE_TOOLS = [
|
||
|
{
|
||
|
"type": "function",
|
||
|
"function": {
|
||
|
"name": "get_current_weather",
|
||
|
"description": "Get the current weather in a given location",
|
||
|
"parameters": {
|
||
|
"type": "object",
|
||
|
"properties": {
|
||
|
"city": {
|
||
|
"type":
|
||
|
"string",
|
||
|
"description":
|
||
|
"The city to find the weather for"
|
||
|
", e.g. 'San Francisco'",
|
||
|
},
|
||
|
},
|
||
|
"required": ["city"],
|
||
|
"additionalProperties": False
|
||
|
},
|
||
|
},
|
||
|
"strict": True
|
||
|
},
|
||
|
{
|
||
|
"type": "function",
|
||
|
"function": {
|
||
|
"name": "get_forecast",
|
||
|
"description": "Get the weather forecast for a given location",
|
||
|
"parameters": {
|
||
|
"type": "object",
|
||
|
"properties": {
|
||
|
"city": {
|
||
|
"type":
|
||
|
"string",
|
||
|
"description":
|
||
|
"The city to get the forecast for, e.g. 'New York'",
|
||
|
},
|
||
|
"days": {
|
||
|
"type":
|
||
|
"integer",
|
||
|
"description":
|
||
|
"Number of days to get the forecast for (1-7)",
|
||
|
},
|
||
|
},
|
||
|
"required": ["city", "days"],
|
||
|
"additionalProperties": False
|
||
|
},
|
||
|
},
|
||
|
"strict": True
|
||
|
},
|
||
|
]
|
||
|
|
||
|
|
||
|
def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
|
||
|
should_match: bool):
|
||
|
self = MagicMock(tool_choice="required", tools=tools)
|
||
|
schema = ChatCompletionRequest._get_guided_json_from_tool(self)
|
||
|
assert isinstance(schema, dict)
|
||
|
|
||
|
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
|
||
|
from outlines_core.fsm.json_schema import build_regex_from_schema
|
||
|
regex = build_regex_from_schema(json.dumps(schema))
|
||
|
compiled = re.compile(regex)
|
||
|
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
|
||
|
|
||
|
assert matches == should_match
|
||
|
|
||
|
|
||
|
VALID_TOOL_OUTPUTS = [
|
||
|
([{
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {
|
||
|
"city": "Vienna"
|
||
|
}
|
||
|
}], True),
|
||
|
([{
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {
|
||
|
"city": "Vienna"
|
||
|
}
|
||
|
}, {
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {
|
||
|
"city": "Berlin"
|
||
|
}
|
||
|
}], True),
|
||
|
([{
|
||
|
"name": "get_forecast",
|
||
|
"parameters": {
|
||
|
"city": "Vienna",
|
||
|
"days": 7
|
||
|
}
|
||
|
}], True),
|
||
|
([{
|
||
|
"name": "get_forecast",
|
||
|
"parameters": {
|
||
|
"city": "Vienna",
|
||
|
"days": 7
|
||
|
}
|
||
|
}, {
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {
|
||
|
"city": "Vienna"
|
||
|
}
|
||
|
}], True),
|
||
|
([{
|
||
|
"name": "get_forecast",
|
||
|
"parameters": {
|
||
|
"city": "Vienna",
|
||
|
"days": 7
|
||
|
}
|
||
|
}, {
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {
|
||
|
"city": "Vienna"
|
||
|
}
|
||
|
}, {
|
||
|
"name": "get_forecast",
|
||
|
"parameters": {
|
||
|
"city": "Berlin",
|
||
|
"days": 7
|
||
|
}
|
||
|
}, {
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {
|
||
|
"city": "Berlin"
|
||
|
}
|
||
|
}], True),
|
||
|
]
|
||
|
|
||
|
VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
|
||
|
"sample_output, should_match",
|
||
|
VALID_TOOL_OUTPUTS + [
|
||
|
(None, False),
|
||
|
([], False), # empty list cannot be generated
|
||
|
({}, False), # empty object cannot be generated
|
||
|
([{}], False), # list with empty object cannot be generated
|
||
|
(
|
||
|
[{ # function without required parameters cannot be generated
|
||
|
"name": "get_current_weather"
|
||
|
}],
|
||
|
False),
|
||
|
(
|
||
|
[{ # function without required parameters cannot be generated
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {}
|
||
|
}],
|
||
|
False),
|
||
|
(
|
||
|
[{ # function without required parameters cannot be generated
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": None
|
||
|
}],
|
||
|
False),
|
||
|
(
|
||
|
{ # tool call without lists cannot be generated
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {
|
||
|
"city": "Vienna"
|
||
|
}
|
||
|
},
|
||
|
False),
|
||
|
(
|
||
|
[{ # tool call with extra parameters cannot be generated
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {
|
||
|
"city": "Vienna",
|
||
|
"extra": "value"
|
||
|
}
|
||
|
}],
|
||
|
False),
|
||
|
(
|
||
|
[{ # tool call where parameters are first cannot be generated
|
||
|
"parameters": {
|
||
|
"city": "Vienna"
|
||
|
},
|
||
|
"name": "get_current_weather"
|
||
|
}],
|
||
|
False),
|
||
|
(
|
||
|
[{ # tool call without all required parameters cannot be generated
|
||
|
"name": "get_forecast",
|
||
|
"parameters": {
|
||
|
"city": "Vienna"
|
||
|
}
|
||
|
}],
|
||
|
False),
|
||
|
( # tool call with incorrect name/parameters cannot be generated
|
||
|
[{
|
||
|
"name": "get_weather",
|
||
|
"parameters": {
|
||
|
"city": "Vienna",
|
||
|
"days": 7
|
||
|
}
|
||
|
}], False),
|
||
|
( # tool call with both valid and empty function cannot be generated
|
||
|
[{
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {
|
||
|
"city": "Vienna"
|
||
|
}
|
||
|
}, {}], False),
|
||
|
])
|
||
|
def test_guided_json(sample_output, should_match):
|
||
|
_compile_and_check(tools=TypeAdapter(
|
||
|
list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS),
|
||
|
sample_output=sample_output,
|
||
|
should_match=should_match)
|
||
|
|
||
|
|
||
|
def update_parameters_none(
|
||
|
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
|
||
|
tool.function.parameters = None
|
||
|
return tool
|
||
|
|
||
|
|
||
|
def update_parameters_empty_dict(
|
||
|
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
|
||
|
tool.function.parameters = {}
|
||
|
return tool
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
|
||
|
"sample_output, should_match",
|
||
|
[
|
||
|
(None, False),
|
||
|
([], False), # empty list cannot be generated
|
||
|
({}, False), # empty object cannot be generated
|
||
|
([{}], False), # list with empty object cannot be generated
|
||
|
(
|
||
|
[{ # function without required parameters cannot be generated
|
||
|
"name": "get_current_weather"
|
||
|
}],
|
||
|
False),
|
||
|
(
|
||
|
[{ # function without required parameters cannot be generated
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": None
|
||
|
}],
|
||
|
False),
|
||
|
(
|
||
|
[{ # function with extra parameters cannot be generated
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {
|
||
|
"extra": "value"
|
||
|
}
|
||
|
}],
|
||
|
False),
|
||
|
(
|
||
|
[{ # only function with empty parameters object is valid
|
||
|
"name": "get_current_weather",
|
||
|
"parameters": {}
|
||
|
}],
|
||
|
True),
|
||
|
])
|
||
|
@pytest.mark.parametrize(
|
||
|
"update_parameters",
|
||
|
[update_parameters_none, update_parameters_empty_dict])
|
||
|
def test_guided_json_without_parameters(sample_output, should_match,
|
||
|
update_parameters):
|
||
|
updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
|
||
|
tools = TypeAdapter(
|
||
|
list[ChatCompletionToolsParam]).validate_python(updated_tools)
|
||
|
tools = list(map(update_parameters, tools))
|
||
|
assert all([
|
||
|
tool.function.parameters is None or tool.function.parameters == {}
|
||
|
for tool in tools
|
||
|
])
|
||
|
_compile_and_check(tools=tools,
|
||
|
sample_output=sample_output,
|
||
|
should_match=should_match)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("output", VALID_TOOLS)
|
||
|
@pytest.mark.parametrize("empty_params", [False, True])
|
||
|
@pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||
|
def test_streaming_output_valid(output, empty_params, delta_len):
|
||
|
self = MagicMock()
|
||
|
|
||
|
output = deepcopy(output)
|
||
|
if empty_params:
|
||
|
output = [{"name": o["name"], "parameters": {}} for o in output]
|
||
|
output_json = json.dumps(output)
|
||
|
|
||
|
previous_text = ""
|
||
|
function_name_returned = False
|
||
|
messages = []
|
||
|
for i in range(0, len(output_json), delta_len):
|
||
|
delta_text = output_json[i:i + delta_len]
|
||
|
current_text = previous_text + delta_text
|
||
|
|
||
|
delta_message, function_name_returned = (
|
||
|
OpenAIServingChat.extract_tool_call_required_streaming(
|
||
|
self,
|
||
|
previous_text=previous_text,
|
||
|
current_text=current_text,
|
||
|
delta_text=delta_text,
|
||
|
function_name_returned=function_name_returned))
|
||
|
|
||
|
if delta_message:
|
||
|
messages.append(delta_message)
|
||
|
|
||
|
previous_text = current_text
|
||
|
|
||
|
assert len(messages) > 0
|
||
|
combined_messages = "["
|
||
|
for message in messages:
|
||
|
if message.tool_calls[0].function.name:
|
||
|
if len(combined_messages) > 1:
|
||
|
combined_messages += "},"
|
||
|
|
||
|
combined_messages += '{"name": "' + \
|
||
|
message.tool_calls[0].function.name + \
|
||
|
'", "parameters": ' + \
|
||
|
message.tool_calls[0].function.arguments
|
||
|
else:
|
||
|
combined_messages += message.tool_calls[0].function.arguments
|
||
|
combined_messages += "}]"
|
||
|
assert json.loads(combined_messages) == output
|
||
|
assert json.dumps(json.loads(combined_messages)) == output_json
|