[Frontend] Implement Tool Calling with tool_choice='required' (#13483)
Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
parent: 98d7367b61
commit: cefb9e5a28
@ -1,6 +1,6 @@

# Tool Calling

vLLM currently supports named function calling, as well as the `auto` and `none` options for the `tool_choice` field in the chat completion API. The `tool_choice` option `required` is **not yet supported** but [on the roadmap](gh-issue:13002).
vLLM currently supports named function calling, as well as the `auto`, `required` (as of `vllm>=0.8.3`) and `none` options for the `tool_choice` field in the chat completion API.

## Quickstart

@ -91,6 +91,12 @@ For best results, we recommend ensuring that the expected output format / schema

To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.
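For illustration, a minimal named-tool request might look like the sketch below. It assumes the `client`, `model` and `tools` objects are set up as in the example script added by this commit, and the named tool must be present in the `tools` list:

```python
# Sketch only: `client`, `model` and `tools` are assumed to be defined as in
# the example script below; "get_current_weather" must be one of the tools.
chat_completion = client.chat.completions.create(
    messages=[{"role": "user", "content": "What is the weather in Dallas?"}],
    model=model,
    tools=tools,
    tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
)
print(chat_completion.choices[0].message.tool_calls)
```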
## Required Function Calling

vLLM supports the `tool_choice='required'` option in the chat completion API. Like named function calling, it uses guided decoding, so it is enabled by default and works with any supported model. The required guided decoding features (JSON schema with `anyOf`) are currently only supported in the V0 engine with the guided decoding backend `outlines`. However, support for alternative decoding backends is on the [roadmap](https://docs.vllm.ai/en/latest/getting_started/v1_user_guide.html#feature-model) for the V1 engine.

When `tool_choice='required'` is set, the model is guaranteed to generate one or more tool calls based on the tool list specified in the `tools` parameter. The number of tool calls depends on the user's query. The output format strictly follows the schema defined in the `tools` parameter.
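Under the hood, a `tool_choice='required'` request is compiled into a guided-decoding JSON schema with one `anyOf` branch per tool, so the output is forced to be a non-empty array of `{name, parameters}` objects. The sketch below illustrates the rough shape of that schema for a single `get_current_weather` tool; the actual schema is produced by `ChatCompletionRequest._get_guided_json_from_tool` in this change:

```python
# Illustrative sketch of the guided JSON schema for tool_choice="required".
guided_json = {
    "type": "array",
    "minItems": 1,  # at least one tool call must be generated
    "items": {
        "type": "object",
        # one anyOf branch per tool in the request's `tools` list
        "anyOf": [
            {
                "properties": {
                    "name": {"type": "string", "enum": ["get_current_weather"]},
                    # the tool's own JSON-schema `parameters` object goes here
                    "parameters": {"type": "object", "properties": {}},
                },
                "required": ["name", "parameters"],
            },
        ],
    },
}
```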
## Automatic Function Calling

To enable this feature, you should set the following flags:
@ -0,0 +1,136 @@

# SPDX-License-Identifier: Apache-2.0
"""
To run this example, you can start the vLLM server with the V0 engine
and the `outlines` guided decoding backend:

```bash
VLLM_USE_V1=0 vllm serve unsloth/Llama-3.2-1B-Instruct \
    --guided-decoding-backend outlines
```

This example demonstrates how to generate chat completions
using the OpenAI Python client library.
"""

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description":
                        "The city to find the weather for, "
                        "e.g. 'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description":
                        "The two-letter abbreviation for the state that the "
                        "city is in, e.g. 'CA' which would mean 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_forecast",
            "description": "Get the weather forecast for a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description":
                        "The city to get the forecast for, e.g. 'New York'",
                    },
                    "state": {
                        "type": "string",
                        "description":
                        "The two-letter abbreviation for the state, e.g. 'NY'",
                    },
                    "days": {
                        "type": "integer",
                        "description":
                        "Number of days to get the forecast for (1-7)",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "days", "unit"],
            },
        },
    },
]

messages = [
    {
        "role": "user",
        "content": "Hi! How are you doing today?"
    },
    {
        "role": "assistant",
        "content": "I'm doing well! How can I help you?"
    },
    {
        "role": "user",
        "content": "Can you tell me what the current weather is in Dallas "
        "and the forecast for the next 5 days, in fahrenheit?",
    },
]

chat_completion = client.chat.completions.create(
    messages=messages,
    model=model,
    tools=tools,
    tool_choice="required",
    stream=True  # Enable streaming response
)

for chunk in chat_completion:
    if chunk.choices and chunk.choices[0].delta.tool_calls:
        print(chunk.choices[0].delta.tool_calls)

chat_completion = client.chat.completions.create(messages=messages,
                                                 model=model,
                                                 tools=tools,
                                                 tool_choice="required")

print(chat_completion.choices[0].message.tool_calls)
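One possible way to reassemble the streamed deltas into complete tool calls, mirroring how the streaming test added in this commit concatenates them, is sketched below. It assumes `stream` is the streaming response from a `tool_choice="required"` request such as the first request in the script above; the first delta of each call carries the function name, and later deltas only append argument fragments:

```python
import json

# Sketch only: `stream` is assumed to be a streaming chat completion created
# with tool_choice="required", as in the example script above.
calls = []
for chunk in stream:
    if not (chunk.choices and chunk.choices[0].delta.tool_calls):
        continue
    for delta in chunk.choices[0].delta.tool_calls:
        if delta.function.name:
            # a delta carrying a name starts a new tool call
            calls.append({"name": delta.function.name, "arguments": ""})
        if delta.function.arguments and calls:
            # subsequent deltas stream the JSON arguments incrementally
            calls[-1]["arguments"] += delta.function.arguments

for call in calls:
    print(call["name"], json.loads(call["arguments"]))
```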
@ -786,56 +786,135 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_required_tool_use_not_yet_supported(client: openai.AsyncOpenAI,
|
||||
sample_json_schema):
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_required_tool_use(client: openai.AsyncOpenAI,
|
||||
is_v1_server: bool, model_name: str):
|
||||
if is_v1_server:
|
||||
pytest.skip("sample_json_schema has features unsupported on V1")
|
||||
pytest.skip(
|
||||
"tool_choice='required' requires features unsupported on V1")
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
f"Give an example JSON for an employee profile that "
|
||||
f"fits this schema: {sample_json_schema}"
|
||||
}]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to find the weather for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["country", "unit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_forecast",
|
||||
"description": "Get the weather forecast for a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to get the forecast for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"days": {
|
||||
"type":
|
||||
"integer",
|
||||
"description":
|
||||
"Number of days to get the forecast for (1-7)",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["country", "days", "unit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_completion_tokens=1000,
|
||||
tools=[{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name",
|
||||
"description": "This is a dummy function",
|
||||
"parameters": sample_json_schema
|
||||
}
|
||||
}],
|
||||
tool_choice="required")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi! How are you doing today?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well! How can I help you?"
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the current weather is in Berlin and the "\
|
||||
"forecast for the next 5 days, in fahrenheit?",
|
||||
},
|
||||
]
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_completion_tokens=1000,
|
||||
tools=[{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name",
|
||||
"description": "This is a dummy function",
|
||||
"parameters": sample_json_schema
|
||||
}
|
||||
}],
|
||||
tool_choice="auto")
|
||||
# Non-streaming test
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
extra_body=dict(guided_decoding_backend="outlines"),
|
||||
)
|
||||
|
||||
assert chat_completion.choices[0].message.tool_calls is not None
|
||||
assert len(chat_completion.choices[0].message.tool_calls) > 0
|
||||
|
||||
# Streaming test
|
||||
stream = await client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
extra_body=dict(guided_decoding_backend="outlines"),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
output = []
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.tool_calls:
|
||||
output.extend(chunk.choices[0].delta.tool_calls)
|
||||
|
||||
assert len(output) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
|
||||
is_v1_server: bool,
|
||||
sample_json_schema):
|
||||
|
||||
if is_v1_server:
|
||||
|
@ -43,7 +43,8 @@ def test_chat_completion_request_with_no_tools():
    assert request.tool_choice == 'none'


def test_chat_completion_request_with_tool_choice_but_no_tools():
@pytest.mark.parametrize('tool_choice', ['auto', 'required'])
def test_chat_completion_request_with_tool_choice_but_no_tools(tool_choice):
    with pytest.raises(ValueError,
                       match="When using `tool_choice`, `tools` must be set."):
        ChatCompletionRequest.model_validate({
@ -54,7 +55,7 @@ def test_chat_completion_request_with_tool_choice_but_no_tools():
            'model':
            'facebook/opt-125m',
            'tool_choice':
            'auto'
            tool_choice
        })

    with pytest.raises(ValueError,
@ -67,7 +68,7 @@ def test_chat_completion_request_with_tool_choice_but_no_tools():
            'model':
            'facebook/opt-125m',
            'tool_choice':
            'auto',
            tool_choice,
            'tools':
            None
        })
tests/tool_use/test_tool_choice_required.py (new file, 336 lines)
@ -0,0 +1,336 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import json
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
|
||||
EXAMPLE_TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to find the weather for"
|
||||
", e.g. 'San Francisco'",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
"additionalProperties": False
|
||||
},
|
||||
},
|
||||
"strict": True
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_forecast",
|
||||
"description": "Get the weather forecast for a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to get the forecast for, e.g. 'New York'",
|
||||
},
|
||||
"days": {
|
||||
"type":
|
||||
"integer",
|
||||
"description":
|
||||
"Number of days to get the forecast for (1-7)",
|
||||
},
|
||||
},
|
||||
"required": ["city", "days"],
|
||||
"additionalProperties": False
|
||||
},
|
||||
},
|
||||
"strict": True
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
|
||||
should_match: bool):
|
||||
self = MagicMock(tool_choice="required", tools=tools)
|
||||
schema = ChatCompletionRequest._get_guided_json_from_tool(self)
|
||||
assert isinstance(schema, dict)
|
||||
|
||||
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
|
||||
from outlines_core.fsm.json_schema import build_regex_from_schema
|
||||
regex = build_regex_from_schema(json.dumps(schema))
|
||||
compiled = re.compile(regex)
|
||||
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
|
||||
|
||||
assert matches == should_match
|
||||
|
||||
|
||||
VALID_TOOL_OUTPUTS = [
|
||||
([{
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Berlin"
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}, {
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Berlin",
|
||||
"days": 7
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Berlin"
|
||||
}
|
||||
}], True),
|
||||
]
|
||||
|
||||
VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_output, should_match",
|
||||
VALID_TOOL_OUTPUTS + [
|
||||
(None, False),
|
||||
([], False), # empty list cannot be generated
|
||||
({}, False), # empty object cannot be generated
|
||||
([{}], False), # list with empty object cannot be generated
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather"
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {}
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": None
|
||||
}],
|
||||
False),
|
||||
(
|
||||
{ # tool call without lists cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
},
|
||||
False),
|
||||
(
|
||||
[{ # tool call with extra parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"extra": "value"
|
||||
}
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # tool call where parameters are first cannot be generated
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
},
|
||||
"name": "get_current_weather"
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # tool call without all required parameters cannot be generated
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}],
|
||||
False),
|
||||
( # tool call with incorrect name/parameters cannot be generated
|
||||
[{
|
||||
"name": "get_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}], False),
|
||||
( # tool call with both valid and empty function cannot be generated
|
||||
[{
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}, {}], False),
|
||||
])
|
||||
def test_guided_json(sample_output, should_match):
|
||||
_compile_and_check(tools=TypeAdapter(
|
||||
list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS),
|
||||
sample_output=sample_output,
|
||||
should_match=should_match)
|
||||
|
||||
|
||||
def update_parameters_none(
|
||||
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
|
||||
tool.function.parameters = None
|
||||
return tool
|
||||
|
||||
|
||||
def update_parameters_empty_dict(
|
||||
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
|
||||
tool.function.parameters = {}
|
||||
return tool
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_output, should_match",
|
||||
[
|
||||
(None, False),
|
||||
([], False), # empty list cannot be generated
|
||||
({}, False), # empty object cannot be generated
|
||||
([{}], False), # list with empty object cannot be generated
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather"
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": None
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # function with extra parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"extra": "value"
|
||||
}
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # only function with empty parameters object is valid
|
||||
"name": "get_current_weather",
|
||||
"parameters": {}
|
||||
}],
|
||||
True),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"update_parameters",
|
||||
[update_parameters_none, update_parameters_empty_dict])
|
||||
def test_guided_json_without_parameters(sample_output, should_match,
|
||||
update_parameters):
|
||||
updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
|
||||
tools = TypeAdapter(
|
||||
list[ChatCompletionToolsParam]).validate_python(updated_tools)
|
||||
tools = list(map(update_parameters, tools))
|
||||
assert all([
|
||||
tool.function.parameters is None or tool.function.parameters == {}
|
||||
for tool in tools
|
||||
])
|
||||
_compile_and_check(tools=tools,
|
||||
sample_output=sample_output,
|
||||
should_match=should_match)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("output", VALID_TOOLS)
|
||||
@pytest.mark.parametrize("empty_params", [False, True])
|
||||
@pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
self = MagicMock()
|
||||
|
||||
output = deepcopy(output)
|
||||
if empty_params:
|
||||
output = [{"name": o["name"], "parameters": {}} for o in output]
|
||||
output_json = json.dumps(output)
|
||||
|
||||
previous_text = ""
|
||||
function_name_returned = False
|
||||
messages = []
|
||||
for i in range(0, len(output_json), delta_len):
|
||||
delta_text = output_json[i:i + delta_len]
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message, function_name_returned = (
|
||||
OpenAIServingChat.extract_tool_call_required_streaming(
|
||||
self,
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=function_name_returned))
|
||||
|
||||
if delta_message:
|
||||
messages.append(delta_message)
|
||||
|
||||
previous_text = current_text
|
||||
|
||||
assert len(messages) > 0
|
||||
combined_messages = "["
|
||||
for message in messages:
|
||||
if message.tool_calls[0].function.name:
|
||||
if len(combined_messages) > 1:
|
||||
combined_messages += "},"
|
||||
|
||||
combined_messages += '{"name": "' + \
|
||||
message.tool_calls[0].function.name + \
|
||||
'", "parameters": ' + \
|
||||
message.tool_calls[0].function.arguments
|
||||
else:
|
||||
combined_messages += message.tool_calls[0].function.arguments
|
||||
combined_messages += "}]"
|
||||
assert json.loads(combined_messages) == output
|
||||
assert json.dumps(json.loads(combined_messages)) == output_json
|
@ -61,7 +61,7 @@ class OpenAIBaseModel(BaseModel):
|
||||
field_names = set()
|
||||
for field_name, field in cls.model_fields.items():
|
||||
field_names.add(field_name)
|
||||
if alias := getattr(field, 'alias', None):
|
||||
if alias := getattr(field, "alias", None):
|
||||
field_names.add(alias)
|
||||
cls.field_names = field_names
|
||||
|
||||
@ -70,7 +70,8 @@ class OpenAIBaseModel(BaseModel):
|
||||
logger.warning(
|
||||
"The following fields were present in the request "
|
||||
"but ignored: %s",
|
||||
data.keys() - field_names)
|
||||
data.keys() - field_names,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@ -234,8 +235,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
tools: Optional[list[ChatCompletionToolsParam]] = None
|
||||
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
|
||||
ChatCompletionNamedToolChoiceParam]] = "none"
|
||||
tool_choice: Optional[Union[
|
||||
Literal["none"],
|
||||
Literal["auto"],
|
||||
Literal["required"],
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
]] = "none"
|
||||
|
||||
# NOTE this will be ignored by vLLM -- the model determines the behavior
|
||||
parallel_tool_calls: Optional[bool] = False
|
||||
@ -340,24 +345,28 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
description=(
|
||||
"If specified, will override the default guided decoding backend "
|
||||
"of the server for this specific request. If set, must be either "
|
||||
"'outlines' / 'lm-format-enforcer'"))
|
||||
"'outlines' / 'lm-format-enforcer'"),
|
||||
)
|
||||
guided_whitespace_pattern: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, will override the default whitespace pattern "
|
||||
"for guided json decoding."))
|
||||
"for guided json decoding."),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
"if the served model does not use priority scheduling."),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=lambda: f"{random_uuid()}",
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."))
|
||||
"through out the inference process and return in response."),
|
||||
)
|
||||
logits_processors: Optional[LogitsProcessors] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
@ -415,13 +424,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
ignore_eos=self.ignore_eos,
|
||||
temperature=temperature,
|
||||
length_penalty=self.length_penalty,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output)
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
)
|
||||
|
||||
def to_sampling_params(
|
||||
self,
|
||||
default_max_tokens: int,
|
||||
logits_processor_pattern: Optional[str],
|
||||
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
||||
self,
|
||||
default_max_tokens: int,
|
||||
logits_processor_pattern: Optional[str],
|
||||
default_sampling_params: Optional[dict] = None,
|
||||
) -> SamplingParams:
|
||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
||||
max_tokens = self.max_completion_tokens or self.max_tokens
|
||||
|
||||
@ -475,7 +486,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
grammar=self.guided_grammar,
|
||||
json_object=guided_json_object,
|
||||
backend=self.guided_decoding_backend,
|
||||
whitespace_pattern=self.guided_whitespace_pattern)
|
||||
whitespace_pattern=self.guided_whitespace_pattern,
|
||||
)
|
||||
|
||||
return SamplingParams.from_optional(
|
||||
n=self.n,
|
||||
@ -522,6 +534,41 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
tool = tools[tool_name]
|
||||
return tool.parameters
|
||||
|
||||
if self.tool_choice == "required":
|
||||
# Pydantic schema generation cannot be used since the JSON schema
|
||||
# has to be constructed for a specific instantiation of a tool list
|
||||
# so that parameters of a function are correctly generated
|
||||
# based on the chosen function name
|
||||
def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
|
||||
return {
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"enum": [tool.function.name]
|
||||
},
|
||||
# parameters are always generated as '{}' in the final
|
||||
# output if they are missing from the request
|
||||
# (i.e. are None or '{}') so the schema is
|
||||
# updated to produce an empty object in that case
|
||||
"parameters": tool.function.parameters
|
||||
if tool.function.parameters else {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
},
|
||||
"required": ["name", "parameters"]
|
||||
}
|
||||
|
||||
json_schema = {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"anyOf": [get_tool_schema(tool) for tool in self.tools]
|
||||
}
|
||||
}
|
||||
return json_schema
|
||||
|
||||
return None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@ -572,8 +619,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
"You can only use one kind of guided decoding "
|
||||
"('guided_json', 'guided_regex' or 'guided_choice').")
|
||||
# you can only either use guided decoding or tools, not both
|
||||
if guide_count > 1 and data.get("tool_choice",
|
||||
"none") not in ("none", "auto"):
|
||||
if guide_count > 1 and data.get("tool_choice", "none") not in (
|
||||
"none",
|
||||
"auto",
|
||||
"required",
|
||||
):
|
||||
raise ValueError(
|
||||
"You can only either use guided decoding or tools, not both.")
|
||||
return data
|
||||
@ -602,12 +652,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
"When using `tool_choice`, `tools` must be set.")
|
||||
|
||||
# make sure that tool choice is either a named tool
|
||||
# OR that it's set to "auto"
|
||||
if data["tool_choice"] != "auto" and not isinstance(
|
||||
data["tool_choice"], dict):
|
||||
raise ValueError(
|
||||
"`tool_choice` must either be a named tool, \"auto\", "
|
||||
"or \"none\".")
|
||||
# OR that it's set to "auto" or "required"
|
||||
if data["tool_choice"] not in [
|
||||
"auto", "required"
|
||||
] and not isinstance(data["tool_choice"], dict):
|
||||
raise NotImplementedError(
|
||||
f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\
|
||||
'Only named tools, "none", "auto" or "required" '\
|
||||
'are supported.'
|
||||
)
|
||||
|
||||
# ensure that if "tool_choice" is specified as an object,
|
||||
# it matches a valid tool
|
||||
@ -722,18 +775,21 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
description=(
|
||||
"If specified, will override the default guided decoding backend "
|
||||
"of the server for this specific request. If set, must be one of "
|
||||
"'outlines' / 'lm-format-enforcer'"))
|
||||
"'outlines' / 'lm-format-enforcer'"),
|
||||
)
|
||||
guided_whitespace_pattern: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, will override the default whitespace pattern "
|
||||
"for guided json decoding."))
|
||||
"for guided json decoding."),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
"if the served model does not use priority scheduling."),
|
||||
)
|
||||
logits_processors: Optional[LogitsProcessors] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
@ -745,6 +801,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
"arguments. For example: {'qualname': "
|
||||
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
|
||||
"{'param': 'value'}}."))
|
||||
|
||||
return_tokens_as_token_ids: Optional[bool] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
@ -789,13 +846,15 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
ignore_eos=self.ignore_eos,
|
||||
temperature=temperature,
|
||||
length_penalty=self.length_penalty,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output)
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
)
|
||||
|
||||
def to_sampling_params(
|
||||
self,
|
||||
default_max_tokens: int,
|
||||
logits_processor_pattern: Optional[str],
|
||||
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
||||
self,
|
||||
default_max_tokens: int,
|
||||
logits_processor_pattern: Optional[str],
|
||||
default_sampling_params: Optional[dict] = None,
|
||||
) -> SamplingParams:
|
||||
max_tokens = self.max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
@ -844,7 +903,8 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
grammar=self.guided_grammar,
|
||||
json_object=guided_json_object,
|
||||
backend=self.guided_decoding_backend,
|
||||
whitespace_pattern=self.guided_whitespace_pattern)
|
||||
whitespace_pattern=self.guided_whitespace_pattern,
|
||||
)
|
||||
|
||||
return SamplingParams.from_optional(
|
||||
n=self.n,
|
||||
@ -942,7 +1002,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
"if the served model does not use priority scheduling."),
|
||||
)
|
||||
|
||||
# doc: end-embedding-extra-params
|
||||
|
||||
@ -995,7 +1056,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
"if the served model does not use priority scheduling."),
|
||||
)
|
||||
# doc: end-chat-embedding-extra-params
|
||||
|
||||
@model_validator(mode="before")
|
||||
@ -1034,7 +1096,8 @@ class ScoreRequest(OpenAIBaseModel):
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
"if the served model does not use priority scheduling."),
|
||||
)
|
||||
|
||||
# doc: end-score-extra-params
|
||||
|
||||
@ -1059,7 +1122,8 @@ class RerankRequest(OpenAIBaseModel):
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
"if the served model does not use priority scheduling."),
|
||||
)
|
||||
|
||||
# doc: end-rerank-extra-params
|
||||
|
||||
|
@ -2,13 +2,16 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Callable, Final, Optional, Union
|
||||
|
||||
import jinja2
|
||||
import partial_json_parser
|
||||
from fastapi import Request
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
@ -21,8 +24,8 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
|
||||
RequestResponseMetadata, ToolCall, UsageInfo)
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition,
|
||||
PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
clamp_prompt_logprobs)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
@ -150,12 +153,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
tool_parser = self.tool_parser
|
||||
|
||||
# validation for OpenAI tools
|
||||
# tool_choice = "required" is not supported
|
||||
if request.tool_choice == "required":
|
||||
return self.create_error_response(
|
||||
"tool_choice = \"required\" is not supported!")
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
@ -277,6 +274,122 @@ class OpenAIServingChat(OpenAIServing):
|
||||
return self.response_role
|
||||
return request.messages[-1]["role"]
|
||||
|
||||
@staticmethod
|
||||
def _bracket_level(s: str, opening='{', closing='}') -> int:
|
||||
"""
|
||||
Calculate the current level of nested brackets in a given string.
|
||||
"""
|
||||
level = 0
|
||||
for char in s:
|
||||
if char == opening:
|
||||
level += 1
|
||||
elif char == closing:
|
||||
level -= 1
|
||||
return level
|
||||
|
||||
@staticmethod
|
||||
def _filter_delta_text(delta_text: str,
|
||||
previous_text: str) -> tuple[str, bool]:
|
||||
# remove last '},' of the tool definition stemming from the
|
||||
# "name"/"parameters" outer object or closing ']' of the tool list
|
||||
# count occurrences of opening and closing curly braces and
|
||||
# once level 0 is reached stop outputting text
|
||||
# if 0 is reached while parsing the delta_text we know the current
|
||||
# tool will finish in this current iteration
|
||||
bracket_level = OpenAIServingChat._bracket_level(previous_text)
|
||||
updated_delta, passed_zero = "", False
|
||||
for c in delta_text:
|
||||
if c == '{':
|
||||
bracket_level += 1
|
||||
passed_zero = bracket_level == 0
|
||||
elif c == '}':
|
||||
bracket_level -= 1
|
||||
passed_zero = bracket_level == 0
|
||||
|
||||
if bracket_level != 0:
|
||||
updated_delta += c
|
||||
else:
|
||||
# if a comma is reached at level 0 we can stop
|
||||
if c == ',':
|
||||
break
|
||||
return updated_delta, passed_zero
|
||||
|
||||
def extract_tool_call_required_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
function_name_returned: bool,
|
||||
) -> tuple[Optional[DeltaMessage], bool]:
|
||||
try:
|
||||
obj = partial_json_parser.loads(current_text)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
obj = None
|
||||
|
||||
# check if the current text is a valid array
|
||||
# containing a partial tool calling object
|
||||
# if not repeat
|
||||
if obj is None or not isinstance(obj, list) or not len(obj) > 0:
|
||||
function_name_returned = False
|
||||
delta_message = None
|
||||
else:
|
||||
_, finishes_previous_tool = OpenAIServingChat._filter_delta_text(
|
||||
delta_text, previous_text)
|
||||
# take the last tool call from the generated list
|
||||
current_tool_call = obj[-1]
|
||||
|
||||
# once parameters have been generated the name is complete as well
|
||||
if not finishes_previous_tool and ("name" not in current_tool_call
|
||||
or "parameters"
|
||||
not in current_tool_call):
|
||||
function_name_returned = False
|
||||
delta_message = None
|
||||
else:
|
||||
if not function_name_returned:
|
||||
# get partly generated arguments from the latest tool call
|
||||
param_match = re.search(r'.*"parameters":\s*(.*)',
|
||||
current_text)
|
||||
arguments = param_match.group(1) if param_match else ""
|
||||
arguments, _ = OpenAIServingChat._filter_delta_text(
|
||||
arguments, previous_text)
|
||||
|
||||
# if this iteration finishes a previous tool call but a
|
||||
# new incomplete tool is already generated, take the
|
||||
# previous from the list
|
||||
if (finishes_previous_tool
|
||||
and "parameters" not in current_tool_call):
|
||||
current_tool_call = obj[-2]
|
||||
|
||||
function_name_returned = True
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(function=DeltaFunctionCall(
|
||||
name=current_tool_call["name"],
|
||||
arguments=arguments),
|
||||
index=len(obj) - 1,
|
||||
type="function")
|
||||
])
|
||||
|
||||
else:
|
||||
delta_text, _ = OpenAIServingChat._filter_delta_text(
|
||||
delta_text, previous_text)
|
||||
|
||||
if delta_text != "":
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
function=DeltaFunctionCall(
|
||||
# OpenAI API returns None
|
||||
# instead of name every time
|
||||
name=None,
|
||||
arguments=delta_text),
|
||||
index=len(obj) - 1,
|
||||
type="function")
|
||||
])
|
||||
else:
|
||||
delta_message = None
|
||||
|
||||
return delta_message, function_name_returned
|
||||
|
||||
async def chat_completion_stream_generator(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
@ -312,6 +425,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self._should_stream_with_reasoning_parsing(request))
|
||||
|
||||
all_previous_token_ids: Optional[list[list[int]]]
|
||||
function_name_returned: Optional[list[bool]] = None
|
||||
|
||||
# Only one of these will be used, thus previous_texts and
|
||||
# all_previous_token_ids will not be used twice in the same iteration.
|
||||
@ -322,6 +436,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# For reasoning parser and tool call all enabled
|
||||
added_content_delta_arr = [False] * num_choices
|
||||
reasoning_end_arr = [False] * num_choices
|
||||
elif request.tool_choice == "required":
|
||||
previous_texts = [""] * num_choices
|
||||
function_name_returned = [False] * num_choices
|
||||
all_previous_token_ids = None
|
||||
else:
|
||||
previous_texts, all_previous_token_ids = None, None
|
||||
|
||||
@ -521,6 +639,23 @@ class OpenAIServingChat(OpenAIServing):
|
||||
index=i)
|
||||
])
|
||||
|
||||
elif request.tool_choice == "required":
|
||||
assert previous_texts is not None
|
||||
assert function_name_returned is not None
|
||||
previous_text = previous_texts[i]
|
||||
current_text = previous_text + delta_text
|
||||
fn_name_returned = function_name_returned[i]
|
||||
|
||||
delta_message, function_name_returned[i] = (
|
||||
self.extract_tool_call_required_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=fn_name_returned))
|
||||
|
||||
# update the previous values for the next iteration
|
||||
previous_texts[i] = current_text
|
||||
|
||||
# handle streaming deltas for tools with "auto" tool choice
|
||||
# and reasoning parser
|
||||
elif tool_choice_auto and self.enable_reasoning:
|
||||
@ -821,10 +956,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
# if auto tools are not enabled, and a named tool choice using
|
||||
# outlines is not being used
|
||||
if (not self.enable_auto_tools
|
||||
or not self.tool_parser) and not isinstance(
|
||||
request.tool_choice,
|
||||
ChatCompletionNamedToolChoiceParam):
|
||||
if (not self.enable_auto_tools or not self.tool_parser) and \
|
||||
(not isinstance(request.tool_choice,
|
||||
ChatCompletionNamedToolChoiceParam
|
||||
) and request.tool_choice != "required"):
|
||||
message = ChatMessage(role=role,
|
||||
reasoning_content=reasoning_content,
|
||||
content=content)
|
||||
@ -845,6 +980,24 @@ class OpenAIServingChat(OpenAIServing):
|
||||
arguments=content))
|
||||
])
|
||||
|
||||
elif request.tool_choice and request.tool_choice == "required":
|
||||
tool_call_class = MistralToolCall if isinstance(
|
||||
tokenizer, MistralTokenizer) else ToolCall
|
||||
|
||||
# the fields of FunctionDefinition are a superset of the
|
||||
# tool call outputs and can be used for parsing
|
||||
tool_calls = TypeAdapter(
|
||||
list[FunctionDefinition]).validate_json(output.text)
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
content="",
|
||||
tool_calls=[
|
||||
tool_call_class(function=FunctionCall(
|
||||
name=tool_call.name,
|
||||
arguments=json.dumps(tool_call.parameters)))
|
||||
for tool_call in tool_calls
|
||||
])
|
||||
|
||||
# if the request doesn't use tool choice
|
||||
# OR specifies to not use a tool
|
||||
elif not request.tool_choice or request.tool_choice == "none":