[Frontend] Implement Tool Calling with tool_choice='required' (#13483)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Matthias Matt 2025-04-02 16:45:45 +02:00 committed by GitHub
parent 98d7367b61
commit cefb9e5a28
7 changed files with 868 additions and 93 deletions

View File

@ -1,6 +1,6 @@
# Tool Calling

vLLM currently supports named function calling, as well as the `auto`, `required` (as of `vllm>=0.8.3`) and `none` options for the `tool_choice` field in the chat completion API.

## Quickstart

@ -91,6 +91,12 @@ For best results, we recommend ensuring that the expected output format / schema
To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.

## Required Function Calling

vLLM supports the `tool_choice='required'` option in the chat completion API. Like named function calling, it uses guided decoding, so it is enabled by default and works with any supported model. The guided decoding features it relies on (a JSON schema with `anyOf`) are currently only supported in the V0 engine with the `outlines` guided decoding backend. Support for alternative decoding backends is on the [roadmap](https://docs.vllm.ai/en/latest/getting_started/v1_user_guide.html#feature-model) for the V1 engine.

When `tool_choice='required'` is set, the model is guaranteed to generate one or more tool calls based on the tool list specified in the `tools` parameter. The number of tool calls depends on the user's query, and the output strictly follows the schema defined in `tools`.
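
A minimal request can look like the following sketch (the model and the tool definition are illustrative; any served model and any valid tool schema behave the same way):

```python
from openai import OpenAI

# Assumes a vLLM server is already running locally on the default port.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "e.g. 'Vienna'"}
            },
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model=client.models.list().data[0].id,  # use whatever model is served
    messages=[{"role": "user", "content": "What is the weather like in Vienna?"}],
    tools=tools,
    tool_choice="required",  # the model must emit at least one tool call
)
print(response.choices[0].message.tool_calls)
```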
## Automatic Function Calling

To enable this feature, you should set the following flags:

View File

@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
"""
To run this example, start the vLLM server with the V0 engine
and the `outlines` guided decoding backend:
```bash
VLLM_USE_V1=0 vllm serve unsloth/Llama-3.2-1B-Instruct \
--guided-decoding-backend outlines
```
This example demonstrates how to generate chat completions
using the OpenAI Python client library.
"""
from openai import OpenAI
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for"
", e.g. 'San Francisco'",
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the "
"city is in, e.g. 'CA' which would mean 'California'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "unit"],
},
},
},
{
"type": "function",
"function": {
"name": "get_forecast",
"description": "Get the weather forecast for a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to get the forecast for, e.g. 'New York'",
},
"state": {
"type":
"string",
"description":
"The two-letter abbreviation for the state, e.g. 'NY'",
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "days", "unit"],
},
},
},
]
messages = [
{
"role": "user",
"content": "Hi! How are you doing today?"
},
{
"role": "assistant",
"content": "I'm doing well! How can I help you?"
},
{
"role":
"user",
"content":
"Can you tell me what the current weather is in Dallas \
and the forecast for the next 5 days, in fahrenheit?",
},
]
chat_completion = client.chat.completions.create(
messages=messages,
model=model,
tools=tools,
tool_choice="required",
stream=True # Enable streaming response
)
for chunk in chat_completion:
if chunk.choices and chunk.choices[0].delta.tool_calls:
print(chunk.choices[0].delta.tool_calls)
chat_completion = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
tool_choice="required")
print(chat_completion.choices[0].message.tool_calls)
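# Illustrative (not exact) shape of the non-streaming result printed above:
# the OpenAI client returns a list of ChatCompletionMessageToolCall objects,
# each carrying a generated function name and a JSON string of arguments,
# e.g. one call to get_current_weather for Dallas and one to get_forecast
# with days=5, both with unit="fahrenheit".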

View File

@ -786,56 +786,135 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool,
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_required_tool_use(client: openai.AsyncOpenAI,
is_v1_server: bool, model_name: str):
if is_v1_server:
pytest.skip(
"tool_choice='required' requires features unsupported on V1")
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to find the weather for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"unit": {
"type": "string",
"description":
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["country", "unit"],
},
},
},
{
"type": "function",
"function": {
"name": "get_forecast",
"description": "Get the weather forecast for a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to get the forecast for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
},
"unit": {
"type": "string",
"description":
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["country", "days", "unit"],
},
},
},
]
messages = [
{
"role": "user",
"content": "Hi! How are you doing today?"
},
{
"role": "assistant",
"content": "I'm doing well! How can I help you?"
},
{
"role":
"user",
"content":
"Can you tell me what the current weather is in Berlin and the "\
"forecast for the next 5 days, in fahrenheit?",
},
]
# Non-streaming test
chat_completion = await client.chat.completions.create(
messages=messages,
model=model_name,
tools=tools,
tool_choice="required",
extra_body=dict(guided_decoding_backend="outlines"),
)
assert chat_completion.choices[0].message.tool_calls is not None
assert len(chat_completion.choices[0].message.tool_calls) > 0
# Streaming test
stream = await client.chat.completions.create(
messages=messages,
model=model_name,
tools=tools,
tool_choice="required",
extra_body=dict(guided_decoding_backend="outlines"),
stream=True,
)
output = []
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.tool_calls:
output.extend(chunk.choices[0].delta.tool_calls)
assert len(output) > 0
@pytest.mark.asyncio
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
is_v1_server: bool,
sample_json_schema):
if is_v1_server:

View File

@ -43,7 +43,8 @@ def test_chat_completion_request_with_no_tools():
assert request.tool_choice == 'none'
@pytest.mark.parametrize('tool_choice', ['auto', 'required'])
def test_chat_completion_request_with_tool_choice_but_no_tools(tool_choice):
with pytest.raises(ValueError,
match="When using `tool_choice`, `tools` must be set."):
ChatCompletionRequest.model_validate({
@ -54,7 +55,7 @@ def test_chat_completion_request_with_tool_choice_but_no_tools():
'model':
'facebook/opt-125m',
'tool_choice':
tool_choice
})
with pytest.raises(ValueError,
@ -67,7 +68,7 @@ def test_chat_completion_request_with_tool_choice_but_no_tools():
'model':
'facebook/opt-125m',
'tool_choice':
tool_choice,
'tools':
None
})

View File

@ -0,0 +1,336 @@
# SPDX-License-Identifier: Apache-2.0
import json
import re
from copy import deepcopy
from unittest.mock import MagicMock
import pytest
from pydantic import TypeAdapter
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
EXAMPLE_TOOLS = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for"
", e.g. 'San Francisco'",
},
},
"required": ["city"],
"additionalProperties": False
},
},
"strict": True
},
{
"type": "function",
"function": {
"name": "get_forecast",
"description": "Get the weather forecast for a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to get the forecast for, e.g. 'New York'",
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
},
},
"required": ["city", "days"],
"additionalProperties": False
},
},
"strict": True
},
]
def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
should_match: bool):
self = MagicMock(tool_choice="required", tools=tools)
schema = ChatCompletionRequest._get_guided_json_from_tool(self)
assert isinstance(schema, dict)
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
from outlines_core.fsm.json_schema import build_regex_from_schema
regex = build_regex_from_schema(json.dumps(schema))
compiled = re.compile(regex)
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
assert matches == should_match
VALID_TOOL_OUTPUTS = [
([{
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}], True),
([{
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Berlin"
}
}], True),
([{
"name": "get_forecast",
"parameters": {
"city": "Vienna",
"days": 7
}
}], True),
([{
"name": "get_forecast",
"parameters": {
"city": "Vienna",
"days": 7
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}], True),
([{
"name": "get_forecast",
"parameters": {
"city": "Vienna",
"days": 7
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}, {
"name": "get_forecast",
"parameters": {
"city": "Berlin",
"days": 7
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Berlin"
}
}], True),
]
VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
@pytest.mark.parametrize(
"sample_output, should_match",
VALID_TOOL_OUTPUTS + [
(None, False),
([], False), # empty list cannot be generated
({}, False), # empty object cannot be generated
([{}], False), # list with empty object cannot be generated
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather"
}],
False),
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": {}
}],
False),
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": None
}],
False),
(
{ # tool call without lists cannot be generated
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
},
False),
(
[{ # tool call with extra parameters cannot be generated
"name": "get_current_weather",
"parameters": {
"city": "Vienna",
"extra": "value"
}
}],
False),
(
[{ # tool call where parameters are first cannot be generated
"parameters": {
"city": "Vienna"
},
"name": "get_current_weather"
}],
False),
(
[{ # tool call without all required parameters cannot be generated
"name": "get_forecast",
"parameters": {
"city": "Vienna"
}
}],
False),
( # tool call with incorrect name/parameters cannot be generated
[{
"name": "get_weather",
"parameters": {
"city": "Vienna",
"days": 7
}
}], False),
( # tool call with both valid and empty function cannot be generated
[{
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}, {}], False),
])
def test_guided_json(sample_output, should_match):
_compile_and_check(tools=TypeAdapter(
list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS),
sample_output=sample_output,
should_match=should_match)
def update_parameters_none(
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
tool.function.parameters = None
return tool
def update_parameters_empty_dict(
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
tool.function.parameters = {}
return tool
@pytest.mark.parametrize(
"sample_output, should_match",
[
(None, False),
([], False), # empty list cannot be generated
({}, False), # empty object cannot be generated
([{}], False), # list with empty object cannot be generated
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather"
}],
False),
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": None
}],
False),
(
[{ # function with extra parameters cannot be generated
"name": "get_current_weather",
"parameters": {
"extra": "value"
}
}],
False),
(
[{ # only function with empty parameters object is valid
"name": "get_current_weather",
"parameters": {}
}],
True),
])
@pytest.mark.parametrize(
"update_parameters",
[update_parameters_none, update_parameters_empty_dict])
def test_guided_json_without_parameters(sample_output, should_match,
update_parameters):
updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
tools = TypeAdapter(
list[ChatCompletionToolsParam]).validate_python(updated_tools)
tools = list(map(update_parameters, tools))
assert all([
tool.function.parameters is None or tool.function.parameters == {}
for tool in tools
])
_compile_and_check(tools=tools,
sample_output=sample_output,
should_match=should_match)
@pytest.mark.parametrize("output", VALID_TOOLS)
@pytest.mark.parametrize("empty_params", [False, True])
@pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def test_streaming_output_valid(output, empty_params, delta_len):
self = MagicMock()
output = deepcopy(output)
if empty_params:
output = [{"name": o["name"], "parameters": {}} for o in output]
output_json = json.dumps(output)
previous_text = ""
function_name_returned = False
messages = []
for i in range(0, len(output_json), delta_len):
delta_text = output_json[i:i + delta_len]
current_text = previous_text + delta_text
delta_message, function_name_returned = (
OpenAIServingChat.extract_tool_call_required_streaming(
self,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned))
if delta_message:
messages.append(delta_message)
previous_text = current_text
assert len(messages) > 0
combined_messages = "["
for message in messages:
if message.tool_calls[0].function.name:
if len(combined_messages) > 1:
combined_messages += "},"
combined_messages += '{"name": "' + \
message.tool_calls[0].function.name + \
'", "parameters": ' + \
message.tool_calls[0].function.arguments
else:
combined_messages += message.tool_calls[0].function.arguments
combined_messages += "}]"
assert json.loads(combined_messages) == output
assert json.dumps(json.loads(combined_messages)) == output_json

View File

@ -61,7 +61,7 @@ class OpenAIBaseModel(BaseModel):
field_names = set()
for field_name, field in cls.model_fields.items():
field_names.add(field_name)
if alias := getattr(field, "alias", None):
field_names.add(alias)
cls.field_names = field_names
@ -70,7 +70,8 @@ class OpenAIBaseModel(BaseModel):
logger.warning(
"The following fields were present in the request "
"but ignored: %s",
data.keys() - field_names,
)
return result
@ -234,8 +235,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: Optional[float] = None
top_p: Optional[float] = None
tools: Optional[list[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[
Literal["none"],
Literal["auto"],
Literal["required"],
ChatCompletionNamedToolChoiceParam,
]] = "none"
# NOTE this will be ignored by vLLM -- the model determines the behavior
parallel_tool_calls: Optional[bool] = False
@ -340,24 +345,28 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."),
)
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."),
)
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
@ -415,13 +424,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
@ -475,7 +486,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional(
n=self.n,
@ -522,6 +534,41 @@ class ChatCompletionRequest(OpenAIBaseModel):
tool = tools[tool_name]
return tool.parameters
if self.tool_choice == "required":
# Pydantic schema generation cannot be used since the JSON schema
# has to be constructed for a specific instantiation of a tool list
# so that parameters of a function are correctly generated
# based on the chosen function name
def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
return {
"properties": {
"name": {
"type": "string",
"enum": [tool.function.name]
},
# parameters are always generated as '{}' in the final
# output if they are missing from the request
# (i.e. are None or '{}') so the schema is
# updated to produce an empty object in that case
"parameters": tool.function.parameters
if tool.function.parameters else {
"type": "object",
"properties": {}
}
},
"required": ["name", "parameters"]
}
json_schema = {
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": [get_tool_schema(tool) for tool in self.tools]
}
}
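# Illustration (assumption, not part of the request flow): for a single
# tool named "get_current_weather" whose parameters schema requires a
# "city" string, the json_schema built above would be:
# {
#     "type": "array",
#     "minItems": 1,
#     "items": {
#         "type": "object",
#         "anyOf": [{
#             "properties": {
#                 "name": {"type": "string",
#                          "enum": ["get_current_weather"]},
#                 "parameters": {
#                     "type": "object",
#                     "properties": {"city": {"type": "string"}},
#                     "required": ["city"]
#                 }
#             },
#             "required": ["name", "parameters"]
#         }]
#     }
# }
# so guided decoding can only emit a non-empty array of well-formed calls.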
return json_schema
return None
@model_validator(mode="before")
@ -572,8 +619,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and data.get("tool_choice", "none") not in (
"none",
"auto",
"required",
):
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data
@ -602,12 +652,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set.")
# make sure that tool choice is either a named tool
# OR that it's set to "auto" or "required"
if data["tool_choice"] not in [
"auto", "required"
] and not isinstance(data["tool_choice"], dict):
raise NotImplementedError(
f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\
'Only named tools, "none", "auto" or "required" '\
'are supported.'
)
# ensure that if "tool_choice" is specified as an object,
# it matches a valid tool
@ -722,18 +775,21 @@ class CompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."),
)
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
@ -745,6 +801,7 @@ class CompletionRequest(OpenAIBaseModel):
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."))
return_tokens_as_token_ids: Optional[bool] = Field(
default=None,
description=(
@ -789,13 +846,15 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
max_tokens = self.max_tokens
if default_sampling_params is None:
@ -844,7 +903,8 @@ class CompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional(
n=self.n,
@ -942,7 +1002,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."),
)
# doc: end-embedding-extra-params
@ -995,7 +1056,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."),
)
# doc: end-chat-embedding-extra-params
@model_validator(mode="before")
@ -1034,7 +1096,8 @@ class ScoreRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."),
)
# doc: end-score-extra-params
@ -1059,7 +1122,8 @@ class RerankRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."),
)
# doc: end-rerank-extra-params

View File

@ -2,13 +2,16 @@
import asyncio
import json
import re
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Callable, Final, Optional, Union
import jinja2
import partial_json_parser
from fastapi import Request
from pydantic import TypeAdapter
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
@ -21,8 +24,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition,
PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@ -150,12 +153,6 @@ class OpenAIServingChat(OpenAIServing):
tool_parser = self.tool_parser
if isinstance(tokenizer, MistralTokenizer):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
@ -277,6 +274,122 @@ class OpenAIServingChat(OpenAIServing):
return self.response_role
return request.messages[-1]["role"]
@staticmethod
def _bracket_level(s: str, opening='{', closing='}') -> int:
"""
Calculate the current level of nested brackets in a given string.
"""
level = 0
for char in s:
if char == opening:
level += 1
elif char == closing:
level -= 1
return level
@staticmethod
def _filter_delta_text(delta_text: str,
previous_text: str) -> tuple[str, bool]:
# remove last '},' of the tool definition stemming from the
# "name"/"parameters" outer object or closing ']' of the tool list
# count occurrences of opening and closing curly braces and
# once level 0 is reached stop outputting text
# if 0 is reached while parsing the delta_text we know the current
# tool will finish in this current iteration
bracket_level = OpenAIServingChat._bracket_level(previous_text)
updated_delta, passed_zero = "", False
for c in delta_text:
if c == '{':
bracket_level += 1
passed_zero = bracket_level == 0
elif c == '}':
bracket_level -= 1
passed_zero = bracket_level == 0
if bracket_level != 0:
updated_delta += c
else:
# if a comma is reached at level 0 we can stop
if c == ',':
break
return updated_delta, passed_zero
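# Illustration (assumption): if previous_text so far is
#   '[{"name": "get_current_weather", "parameters": {"city": "Vienna"'
# the bracket level is 2; for delta_text '}}]' the first '}' (closing the
# parameters object) is kept, the second '}' drops the level to 0 so it and
# the trailing ']' are filtered out, and passed_zero is returned as True.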
def extract_tool_call_required_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
function_name_returned: bool,
) -> tuple[Optional[DeltaMessage], bool]:
try:
obj = partial_json_parser.loads(current_text)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
obj = None
# check if the current text is a valid array
# containing a partial tool calling object
# if not repeat
if obj is None or not isinstance(obj, list) or not len(obj) > 0:
function_name_returned = False
delta_message = None
else:
_, finishes_previous_tool = OpenAIServingChat._filter_delta_text(
delta_text, previous_text)
# take the last tool call from the generated list
current_tool_call = obj[-1]
# once parameters have been generated the name is complete as well
if not finishes_previous_tool and ("name" not in current_tool_call
or "parameters"
not in current_tool_call):
function_name_returned = False
delta_message = None
else:
if not function_name_returned:
# get partly generated arguments from the latest tool call
param_match = re.search(r'.*"parameters":\s*(.*)',
current_text)
arguments = param_match.group(1) if param_match else ""
arguments, _ = OpenAIServingChat._filter_delta_text(
arguments, previous_text)
# if this iteration finishes a previous tool call but a
# new incomplete tool is already generated, take the
# previous from the list
if (finishes_previous_tool
and "parameters" not in current_tool_call):
current_tool_call = obj[-2]
function_name_returned = True
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=current_tool_call["name"],
arguments=arguments),
index=len(obj) - 1,
type="function")
])
else:
delta_text, _ = OpenAIServingChat._filter_delta_text(
delta_text, previous_text)
if delta_text != "":
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(
function=DeltaFunctionCall(
# OpenAI API returns None
# instead of name every time
name=None,
arguments=delta_text),
index=len(obj) - 1,
type="function")
])
else:
delta_message = None
return delta_message, function_name_returned
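# Illustration (assumption): while the engine streams
# '[{"name": "get_current_weather", "parameters": {"city": "Vi...' the
# first delta that contains both keys yields a DeltaToolCall carrying the
# function name plus the argument text generated so far; every later delta
# yields name=None with only the newly filtered argument text, so the
# client can concatenate the argument fragments per tool call index.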
async def chat_completion_stream_generator(
self,
request: ChatCompletionRequest,
@ -312,6 +425,7 @@ class OpenAIServingChat(OpenAIServing):
self._should_stream_with_reasoning_parsing(request))
all_previous_token_ids: Optional[list[list[int]]]
function_name_returned: Optional[list[bool]] = None
# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
@ -322,6 +436,10 @@ class OpenAIServingChat(OpenAIServing):
# For reasoning parser and tool call all enabled
added_content_delta_arr = [False] * num_choices
reasoning_end_arr = [False] * num_choices
elif request.tool_choice == "required":
previous_texts = [""] * num_choices
function_name_returned = [False] * num_choices
all_previous_token_ids = None
else:
previous_texts, all_previous_token_ids = None, None
@ -521,6 +639,23 @@ class OpenAIServingChat(OpenAIServing):
index=i)
])
elif request.tool_choice == "required":
assert previous_texts is not None
assert function_name_returned is not None
previous_text = previous_texts[i]
current_text = previous_text + delta_text
fn_name_returned = function_name_returned[i]
delta_message, function_name_returned[i] = (
self.extract_tool_call_required_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=fn_name_returned))
# update the previous values for the next iteration
previous_texts[i] = current_text
# handle streaming deltas for tools with "auto" tool choice
# and reasoning parser
elif tool_choice_auto and self.enable_reasoning:
@ -821,10 +956,10 @@ class OpenAIServingChat(OpenAIServing):
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
if (not self.enable_auto_tools or not self.tool_parser) and \
(not isinstance(request.tool_choice,
ChatCompletionNamedToolChoiceParam
) and request.tool_choice != "required"):
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
@ -845,6 +980,24 @@ class OpenAIServingChat(OpenAIServing):
arguments=content))
])
elif request.tool_choice and request.tool_choice == "required":
tool_call_class = MistralToolCall if isinstance(
tokenizer, MistralTokenizer) else ToolCall
# the fields of FunctionDefinition are a superset of the
# tool call outputs and can be used for parsing
tool_calls = TypeAdapter(
list[FunctionDefinition]).validate_json(output.text)
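# e.g. (illustrative) output.text ==
# '[{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]'
# parses into FunctionDefinition objects whose name/parameters fields
# populate the ToolCall list below.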
message = ChatMessage(
role=role,
content="",
tool_calls=[
tool_call_class(function=FunctionCall(
name=tool_call.name,
arguments=json.dumps(tool_call.parameters)))
for tool_call in tool_calls
])
# if the request doesn't use tool choice
# OR specifies to not use a tool
elif not request.tool_choice or request.tool_choice == "none":