[Frontend] Implement Tool Calling with tool_choice='required' (#13483)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Matthias Matt 2025-04-02 16:45:45 +02:00 committed by GitHub
parent 98d7367b61
commit cefb9e5a28
7 changed files with 868 additions and 93 deletions


@ -1,6 +1,6 @@
# Tool Calling
-vLLM currently supports named function calling, as well as the `auto` and `none` options for the `tool_choice` field in the chat completion API. The `tool_choice` option `required` is **not yet supported** but [on the roadmap](gh-issue:13002).
+vLLM currently supports named function calling, as well as the `auto`, `required` (as of `vllm>=0.8.3`) and `none` options for the `tool_choice` field in the chat completion API.
## Quickstart
@ -91,6 +91,12 @@ For best results, we recommend ensuring that the expected output format / schema
To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.
## Required Function Calling
vLLM supports the `tool_choice='required'` option in the chat completion API. Like named function calling, it uses guided decoding, so it is enabled by default and works with any supported model. The required guided decoding features (JSON schema with `anyOf`) are currently only supported in the V0 engine with the `outlines` guided decoding backend. Support for alternative decoding backends in the V1 engine is on the [roadmap](https://docs.vllm.ai/en/latest/getting_started/v1_user_guide.html#feature-model).
When `tool_choice='required'` is set, the model is guaranteed to generate one or more tool calls based on the tool list specified in the `tools` parameter. The number of tool calls depends on the user's query. The output format strictly follows the schema defined in the `tools` parameter.
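A minimal request looks like the following sketch (the tool definition, user message, and server address are illustrative placeholders):

```python
from openai import OpenAI

# vLLM's OpenAI-compatible server; the API key is unused but required by the client.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
model = client.models.list().data[0].id

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",  # illustrative tool
        "description": "Get the current weather in a given city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "What is the weather in Vienna?"}],
    tools=tools,
    tool_choice="required",  # the model must emit at least one tool call
)
print(response.choices[0].message.tool_calls)
```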
## Automatic Function Calling
To enable this feature, you should set the following flags:


@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
"""
To run this example, start the vLLM server with the V0 engine
and the `outlines` guided decoding backend:
```bash
VLLM_USE_V1=0 vllm serve unsloth/Llama-3.2-1B-Instruct \
--guided-decoding-backend outlines
```
This example demonstrates how to generate chat completions
using the OpenAI Python client library.
"""
from openai import OpenAI
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for"
", e.g. 'San Francisco'",
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the "
"city is in, e.g. 'CA' which would mean 'California'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "unit"],
},
},
},
{
"type": "function",
"function": {
"name": "get_forecast",
"description": "Get the weather forecast for a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to get the forecast for, e.g. 'New York'",
},
"state": {
"type":
"string",
"description":
"The two-letter abbreviation for the state, e.g. 'NY'",
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "days", "unit"],
},
},
},
]
messages = [
{
"role": "user",
"content": "Hi! How are you doing today?"
},
{
"role": "assistant",
"content": "I'm doing well! How can I help you?"
},
{
"role":
"user",
"content":
"Can you tell me what the current weather is in Dallas \
and the forecast for the next 5 days, in fahrenheit?",
},
]
chat_completion = client.chat.completions.create(
messages=messages,
model=model,
tools=tools,
tool_choice="required",
stream=True # Enable streaming response
)
for chunk in chat_completion:
if chunk.choices and chunk.choices[0].delta.tool_calls:
print(chunk.choices[0].delta.tool_calls)
chat_completion = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
tool_choice="required")
print(chat_completion.choices[0].message.tool_calls)
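# A rough sketch (illustrative, reusing the client/model/tools above) for
# reassembling streamed deltas into complete tool calls: with
# tool_choice='required', the function name arrives in the first delta of each
# call, argument fragments follow, and `index` identifies which call a delta
# belongs to.
assembled: dict[int, dict[str, str]] = {}
stream = client.chat.completions.create(
    messages=messages,
    model=model,
    tools=tools,
    tool_choice="required",
    stream=True,
)
for chunk in stream:
    if not (chunk.choices and chunk.choices[0].delta.tool_calls):
        continue
    for call in chunk.choices[0].delta.tool_calls:
        entry = assembled.setdefault(call.index, {"name": "", "arguments": ""})
        if call.function and call.function.name:
            entry["name"] = call.function.name
        if call.function and call.function.arguments:
            entry["arguments"] += call.function.arguments
print(assembled)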


@ -786,56 +786,135 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool,
@pytest.mark.asyncio
async def test_required_tool_use_not_yet_supported(client: openai.AsyncOpenAI,
sample_json_schema):
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_required_tool_use(client: openai.AsyncOpenAI,
is_v1_server: bool, model_name: str):
if is_v1_server:
pytest.skip("sample_json_schema has features unsupported on V1")
pytest.skip(
"tool_choice='required' requires features unsupported on V1")
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {sample_json_schema}"
}]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to find the weather for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"unit": {
"type": "string",
"description":
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["country", "unit"],
},
},
},
{
"type": "function",
"function": {
"name": "get_forecast",
"description": "Get the weather forecast for a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to get the forecast for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
},
"unit": {
"type": "string",
"description":
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["country", "days", "unit"],
},
},
},
]
with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_completion_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema
}
}],
tool_choice="required")
messages = [
{
"role": "user",
"content": "Hi! How are you doing today?"
},
{
"role": "assistant",
"content": "I'm doing well! How can I help you?"
},
{
"role":
"user",
"content":
"Can you tell me what the current weather is in Berlin and the "\
"forecast for the next 5 days, in fahrenheit?",
},
]
with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_completion_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema
}
}],
tool_choice="auto")
# Non-streaming test
chat_completion = await client.chat.completions.create(
messages=messages,
model=model_name,
tools=tools,
tool_choice="required",
extra_body=dict(guided_decoding_backend="outlines"),
)
assert chat_completion.choices[0].message.tool_calls is not None
assert len(chat_completion.choices[0].message.tool_calls) > 0
# Streaming test
stream = await client.chat.completions.create(
messages=messages,
model=model_name,
tools=tools,
tool_choice="required",
extra_body=dict(guided_decoding_backend="outlines"),
stream=True,
)
output = []
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.tool_calls:
output.extend(chunk.choices[0].delta.tool_calls)
assert len(output) > 0
@pytest.mark.asyncio
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
is_v1_server: bool,
sample_json_schema):
if is_v1_server:


@ -43,7 +43,8 @@ def test_chat_completion_request_with_no_tools():
assert request.tool_choice == 'none'
def test_chat_completion_request_with_tool_choice_but_no_tools():
@pytest.mark.parametrize('tool_choice', ['auto', 'required'])
def test_chat_completion_request_with_tool_choice_but_no_tools(tool_choice):
with pytest.raises(ValueError,
match="When using `tool_choice`, `tools` must be set."):
ChatCompletionRequest.model_validate({
@ -54,7 +55,7 @@ def test_chat_completion_request_with_tool_choice_but_no_tools():
'model':
'facebook/opt-125m',
'tool_choice':
'auto'
tool_choice
})
with pytest.raises(ValueError,
@ -67,7 +68,7 @@ def test_chat_completion_request_with_tool_choice_but_no_tools():
'model':
'facebook/opt-125m',
'tool_choice':
'auto',
tool_choice,
'tools':
None
})


@ -0,0 +1,336 @@
# SPDX-License-Identifier: Apache-2.0
import json
import re
from copy import deepcopy
from unittest.mock import MagicMock
import pytest
from pydantic import TypeAdapter
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
EXAMPLE_TOOLS = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for"
", e.g. 'San Francisco'",
},
},
"required": ["city"],
"additionalProperties": False
},
},
"strict": True
},
{
"type": "function",
"function": {
"name": "get_forecast",
"description": "Get the weather forecast for a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to get the forecast for, e.g. 'New York'",
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
},
},
"required": ["city", "days"],
"additionalProperties": False
},
},
"strict": True
},
]
def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
should_match: bool):
self = MagicMock(tool_choice="required", tools=tools)
schema = ChatCompletionRequest._get_guided_json_from_tool(self)
assert isinstance(schema, dict)
# use build_regex_from_schema (as used in JSONLogitsProcessor) to create the Guide
from outlines_core.fsm.json_schema import build_regex_from_schema
regex = build_regex_from_schema(json.dumps(schema))
compiled = re.compile(regex)
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
assert matches == should_match
VALID_TOOL_OUTPUTS = [
([{
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}], True),
([{
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Berlin"
}
}], True),
([{
"name": "get_forecast",
"parameters": {
"city": "Vienna",
"days": 7
}
}], True),
([{
"name": "get_forecast",
"parameters": {
"city": "Vienna",
"days": 7
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}], True),
([{
"name": "get_forecast",
"parameters": {
"city": "Vienna",
"days": 7
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}, {
"name": "get_forecast",
"parameters": {
"city": "Berlin",
"days": 7
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Berlin"
}
}], True),
]
VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
@pytest.mark.parametrize(
"sample_output, should_match",
VALID_TOOL_OUTPUTS + [
(None, False),
([], False), # empty list cannot be generated
({}, False), # empty object cannot be generated
([{}], False), # list with empty object cannot be generated
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather"
}],
False),
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": {}
}],
False),
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": None
}],
False),
(
{ # tool call without lists cannot be generated
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
},
False),
(
[{ # tool call with extra parameters cannot be generated
"name": "get_current_weather",
"parameters": {
"city": "Vienna",
"extra": "value"
}
}],
False),
(
[{ # tool call where parameters are first cannot be generated
"parameters": {
"city": "Vienna"
},
"name": "get_current_weather"
}],
False),
(
[{ # tool call without all required parameters cannot be generated
"name": "get_forecast",
"parameters": {
"city": "Vienna"
}
}],
False),
( # tool call with incorrect name/parameters cannot be generated
[{
"name": "get_weather",
"parameters": {
"city": "Vienna",
"days": 7
}
}], False),
( # tool call with both valid and empty function cannot be generated
[{
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}, {}], False),
])
def test_guided_json(sample_output, should_match):
_compile_and_check(tools=TypeAdapter(
list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS),
sample_output=sample_output,
should_match=should_match)
def update_parameters_none(
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
tool.function.parameters = None
return tool
def update_parameters_empty_dict(
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
tool.function.parameters = {}
return tool
@pytest.mark.parametrize(
"sample_output, should_match",
[
(None, False),
([], False), # empty list cannot be generated
({}, False), # empty object cannot be generated
([{}], False), # list with empty object cannot be generated
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather"
}],
False),
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": None
}],
False),
(
[{ # function with extra parameters cannot be generated
"name": "get_current_weather",
"parameters": {
"extra": "value"
}
}],
False),
(
[{ # only function with empty parameters object is valid
"name": "get_current_weather",
"parameters": {}
}],
True),
])
@pytest.mark.parametrize(
"update_parameters",
[update_parameters_none, update_parameters_empty_dict])
def test_guided_json_without_parameters(sample_output, should_match,
update_parameters):
updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
tools = TypeAdapter(
list[ChatCompletionToolsParam]).validate_python(updated_tools)
tools = list(map(update_parameters, tools))
assert all([
tool.function.parameters is None or tool.function.parameters == {}
for tool in tools
])
_compile_and_check(tools=tools,
sample_output=sample_output,
should_match=should_match)
@pytest.mark.parametrize("output", VALID_TOOLS)
@pytest.mark.parametrize("empty_params", [False, True])
@pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def test_streaming_output_valid(output, empty_params, delta_len):
self = MagicMock()
output = deepcopy(output)
if empty_params:
output = [{"name": o["name"], "parameters": {}} for o in output]
output_json = json.dumps(output)
previous_text = ""
function_name_returned = False
messages = []
for i in range(0, len(output_json), delta_len):
delta_text = output_json[i:i + delta_len]
current_text = previous_text + delta_text
delta_message, function_name_returned = (
OpenAIServingChat.extract_tool_call_required_streaming(
self,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned))
if delta_message:
messages.append(delta_message)
previous_text = current_text
assert len(messages) > 0
combined_messages = "["
for message in messages:
if message.tool_calls[0].function.name:
if len(combined_messages) > 1:
combined_messages += "},"
combined_messages += '{"name": "' + \
message.tool_calls[0].function.name + \
'", "parameters": ' + \
message.tool_calls[0].function.arguments
else:
combined_messages += message.tool_calls[0].function.arguments
combined_messages += "}]"
assert json.loads(combined_messages) == output
assert json.dumps(json.loads(combined_messages)) == output_json


@ -61,7 +61,7 @@ class OpenAIBaseModel(BaseModel):
field_names = set()
for field_name, field in cls.model_fields.items():
field_names.add(field_name)
if alias := getattr(field, 'alias', None):
if alias := getattr(field, "alias", None):
field_names.add(alias)
cls.field_names = field_names
@ -70,7 +70,8 @@ class OpenAIBaseModel(BaseModel):
logger.warning(
"The following fields were present in the request "
"but ignored: %s",
data.keys() - field_names)
data.keys() - field_names,
)
return result
@ -234,8 +235,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: Optional[float] = None
top_p: Optional[float] = None
tools: Optional[list[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
ChatCompletionNamedToolChoiceParam]] = "none"
tool_choice: Optional[Union[
Literal["none"],
Literal["auto"],
Literal["required"],
ChatCompletionNamedToolChoiceParam,
]] = "none"
# NOTE this will be ignored by vLLM -- the model determines the behavior
parallel_tool_calls: Optional[bool] = False
@ -340,24 +345,28 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
"for guided json decoding."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."))
"through out the inference process and return in response."),
)
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
@ -415,13 +424,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
@ -475,7 +486,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional(
n=self.n,
@ -522,6 +534,41 @@ class ChatCompletionRequest(OpenAIBaseModel):
tool = tools[tool_name]
return tool.parameters
if self.tool_choice == "required":
# Pydantic schema generation cannot be used since the JSON schema
# has to be constructed for a specific instantiation of a tool list
# so that parameters of a function are correctly generated
# based on the chosen function name
def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
return {
"properties": {
"name": {
"type": "string",
"enum": [tool.function.name]
},
# parameters are always generated as '{}' in the final
# output if they are missing from the request
# (i.e. are None or '{}') so the schema is
# updated to produce an empty object in that case
"parameters": tool.function.parameters
if tool.function.parameters else {
"type": "object",
"properties": {}
}
},
"required": ["name", "parameters"]
}
json_schema = {
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": [get_tool_schema(tool) for tool in self.tools]
}
}
return json_schema
return None
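# Illustrative sketch (hypothetical tool, not taken from any real request):
# for a tool list containing a single function "get_weather" whose parameters
# schema requires a string "city", the schema built above is roughly:
#
#   {
#       "type": "array",
#       "minItems": 1,
#       "items": {
#           "type": "object",
#           "anyOf": [{
#               "properties": {
#                   "name": {"type": "string", "enum": ["get_weather"]},
#                   "parameters": {
#                       "type": "object",
#                       "properties": {"city": {"type": "string"}},
#                       "required": ["city"],
#                   },
#               },
#               "required": ["name", "parameters"],
#           }],
#       },
#   }
#
# Guided decoding then constrains generation to a JSON array matching this
# schema, which is what guarantees at least one well-formed tool call.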
@model_validator(mode="before")
@ -572,8 +619,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and data.get("tool_choice",
"none") not in ("none", "auto"):
if guide_count > 1 and data.get("tool_choice", "none") not in (
"none",
"auto",
"required",
):
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data
@ -602,12 +652,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set.")
# make sure that tool choice is either a named tool
# OR that it's set to "auto"
if data["tool_choice"] != "auto" and not isinstance(
data["tool_choice"], dict):
raise ValueError(
"`tool_choice` must either be a named tool, \"auto\", "
"or \"none\".")
# OR that it's set to "auto" or "required"
if data["tool_choice"] not in [
"auto", "required"
] and not isinstance(data["tool_choice"], dict):
raise NotImplementedError(
f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\
'Only named tools, "none", "auto" or "required" '\
'are supported.'
)
# ensure that if "tool_choice" is specified as an object,
# it matches a valid tool
@ -722,18 +775,21 @@ class CompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
"for guided json decoding."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
@ -745,6 +801,7 @@ class CompletionRequest(OpenAIBaseModel):
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."))
return_tokens_as_token_ids: Optional[bool] = Field(
default=None,
description=(
@ -789,13 +846,15 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
max_tokens = self.max_tokens
if default_sampling_params is None:
@ -844,7 +903,8 @@ class CompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional(
n=self.n,
@ -942,7 +1002,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-embedding-extra-params
@ -995,7 +1056,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-chat-embedding-extra-params
@model_validator(mode="before")
@ -1034,7 +1096,8 @@ class ScoreRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-score-extra-params
@ -1059,7 +1122,8 @@ class RerankRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-rerank-extra-params


@ -2,13 +2,16 @@
import asyncio
import json
import re
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Callable, Final, Optional, Union
import jinja2
import partial_json_parser
from fastapi import Request
from pydantic import TypeAdapter
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
@ -21,8 +24,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition,
PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@ -150,12 +153,6 @@ class OpenAIServingChat(OpenAIServing):
tool_parser = self.tool_parser
# validation for OpenAI tools
# tool_choice = "required" is not supported
if request.tool_choice == "required":
return self.create_error_response(
"tool_choice = \"required\" is not supported!")
if isinstance(tokenizer, MistralTokenizer):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
@ -277,6 +274,122 @@ class OpenAIServingChat(OpenAIServing):
return self.response_role
return request.messages[-1]["role"]
@staticmethod
def _bracket_level(s: str, opening='{', closing='}') -> int:
"""
Calculate the current level of nested brackets in a given string.
"""
level = 0
for char in s:
if char == opening:
level += 1
elif char == closing:
level -= 1
return level
@staticmethod
def _filter_delta_text(delta_text: str,
previous_text: str) -> tuple[str, bool]:
# remove last '},' of the tool definition stemming from the
# "name"/"parameters" outer object or closing ']' of the tool list
# count occurrences of opening and closing curly braces and
# once level 0 is reached stop outputting text
# if 0 is reached while parsing the delta_text we know the current
# tool will finish in this current iteration
bracket_level = OpenAIServingChat._bracket_level(previous_text)
updated_delta, passed_zero = "", False
for c in delta_text:
if c == '{':
bracket_level += 1
passed_zero = bracket_level == 0
elif c == '}':
bracket_level -= 1
passed_zero = bracket_level == 0
if bracket_level != 0:
updated_delta += c
else:
# if a comma is reached at level 0 we can stop
if c == ',':
break
return updated_delta, passed_zero
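# Worked example (illustrative values): with
#   previous_text = '[{"name": "get_weather", "parameters": {"city": "Vienna"'
#   delta_text    = '}},'
# the previous text sits at bracket level 2, so the first '}' (closing
# "parameters") is kept, the second '}' (closing the tool object) brings the
# level back to 0 and is dropped, and the ',' at level 0 stops the scan:
# the call returns ('}', True).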
def extract_tool_call_required_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
function_name_returned: bool,
) -> tuple[Optional[DeltaMessage], bool]:
try:
obj = partial_json_parser.loads(current_text)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
obj = None
# check if the current text is a valid array
# containing a partial tool calling object
# if not repeat
if obj is None or not isinstance(obj, list) or not len(obj) > 0:
function_name_returned = False
delta_message = None
else:
_, finishes_previous_tool = OpenAIServingChat._filter_delta_text(
delta_text, previous_text)
# take the last tool call from the generated list
current_tool_call = obj[-1]
# once parameters have been generated the name is complete as well
if not finishes_previous_tool and ("name" not in current_tool_call
or "parameters"
not in current_tool_call):
function_name_returned = False
delta_message = None
else:
if not function_name_returned:
# get partly generated arguments from the latest tool call
param_match = re.search(r'.*"parameters":\s*(.*)',
current_text)
arguments = param_match.group(1) if param_match else ""
arguments, _ = OpenAIServingChat._filter_delta_text(
arguments, previous_text)
# if this iteration finishes a previous tool call but a
# new incomplete tool is already generated, take the
# previous from the list
if (finishes_previous_tool
and "parameters" not in current_tool_call):
current_tool_call = obj[-2]
function_name_returned = True
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=current_tool_call["name"],
arguments=arguments),
index=len(obj) - 1,
type="function")
])
else:
delta_text, _ = OpenAIServingChat._filter_delta_text(
delta_text, previous_text)
if delta_text != "":
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(
function=DeltaFunctionCall(
# OpenAI API returns None
# instead of name every time
name=None,
arguments=delta_text),
index=len(obj) - 1,
type="function")
])
else:
delta_message = None
return delta_message, function_name_returned
async def chat_completion_stream_generator(
self,
request: ChatCompletionRequest,
@ -312,6 +425,7 @@ class OpenAIServingChat(OpenAIServing):
self._should_stream_with_reasoning_parsing(request))
all_previous_token_ids: Optional[list[list[int]]]
function_name_returned: Optional[list[bool]] = None
# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
@ -322,6 +436,10 @@ class OpenAIServingChat(OpenAIServing):
# For reasoning parser and tool call all enabled
added_content_delta_arr = [False] * num_choices
reasoning_end_arr = [False] * num_choices
elif request.tool_choice == "required":
previous_texts = [""] * num_choices
function_name_returned = [False] * num_choices
all_previous_token_ids = None
else:
previous_texts, all_previous_token_ids = None, None
@ -521,6 +639,23 @@ class OpenAIServingChat(OpenAIServing):
index=i)
])
elif request.tool_choice == "required":
assert previous_texts is not None
assert function_name_returned is not None
previous_text = previous_texts[i]
current_text = previous_text + delta_text
fn_name_returned = function_name_returned[i]
delta_message, function_name_returned[i] = (
self.extract_tool_call_required_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=fn_name_returned))
# update the previous values for the next iteration
previous_texts[i] = current_text
# handle streaming deltas for tools with "auto" tool choice
# and reasoning parser
elif tool_choice_auto and self.enable_reasoning:
@ -821,10 +956,10 @@ class OpenAIServingChat(OpenAIServing):
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
if (not self.enable_auto_tools
or not self.tool_parser) and not isinstance(
request.tool_choice,
ChatCompletionNamedToolChoiceParam):
if (not self.enable_auto_tools or not self.tool_parser) and \
(not isinstance(request.tool_choice,
ChatCompletionNamedToolChoiceParam
) and request.tool_choice != "required"):
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
@ -845,6 +980,24 @@ class OpenAIServingChat(OpenAIServing):
arguments=content))
])
elif request.tool_choice and request.tool_choice == "required":
tool_call_class = MistralToolCall if isinstance(
tokenizer, MistralTokenizer) else ToolCall
# the fields of FunctionDefinition are a superset of the
# tool call outputs and can be used for parsing
tool_calls = TypeAdapter(
list[FunctionDefinition]).validate_json(output.text)
message = ChatMessage(
role=role,
content="",
tool_calls=[
tool_call_class(function=FunctionCall(
name=tool_call.name,
arguments=json.dumps(tool_call.parameters)))
for tool_call in tool_calls
])
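# Illustrative sketch: if the guided output were
#   output.text = '[{"name": "get_weather", "parameters": {"city": "Vienna"}}]'
# the TypeAdapter call above parses it into one FunctionDefinition, which is
# re-emitted as a ToolCall whose function.name is "get_weather" and whose
# function.arguments is the JSON string '{"city": "Vienna"}'.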
# if the request doesn't use tool choice
# OR specifies to not use a tool
elif not request.tool_choice or request.tool_choice == "none":