[Frontend] Pythonic tool parser (#9859)

Signed-off-by: Mike Depinet <mike@fixie.ai>
This commit is contained in:
Mike Depinet 2024-11-13 20:14:34 -08:00 committed by GitHub
parent e0853b6508
commit f67ce05d0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 806 additions and 26 deletions

View File

@ -162,7 +162,7 @@ vllm serve <model> --chat-template ./path-to-chat-template.jinja
vLLM community provides a set of chat templates for popular models. You can find them in the examples
directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
With the inclusion of multi-modal chat APIs, the OpenAI spec now accepts chat messages in a new format which specifies
With the inclusion of multi-modal chat APIs, the OpenAI spec now accepts chat messages in a new format which specifies
both a `type` and a `text` field. An example is provided below:
```python
completion = client.chat.completions.create(
@ -172,10 +172,10 @@ completion = client.chat.completions.create(
]
)
```
Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like
Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like
`meta-llama/Llama-Guard-3-1B` that expect the content to be parsed with the new OpenAI spec. In order to choose which
format the content needs to be parsed in by vLLM, please use the `--chat-template-text-format` argument to specify
between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match
between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match
this, unless explicitly specified.
@ -191,8 +191,8 @@ this, unless explicitly specified.
### Config file
The `serve` module can also accept arguments from a config file in
`yaml` format. The arguments in the yaml must be specified using the
long form of the argument outlined [here](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server):
`yaml` format. The arguments in the yaml must be specified using the
long form of the argument outlined [here](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server):
For example:
@ -208,7 +208,7 @@ uvicorn-log-level: "info"
$ vllm serve SOME_MODEL --config config.yaml
```
---
**NOTE**
**NOTE**
In case an argument is supplied simultaneously using command line and the config file, the value from the commandline will take precedence.
The order of priorities is `command line > config file values > defaults`.
@ -222,30 +222,30 @@ Please see below for recommended configuration and chat templates to use when fu
### Named Function Calling
vLLM supports named function calling in the chat completion API by default. It does so using Outlines, so this is
enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a
high-quality one.
vLLM supports named function calling in the chat completion API by default. It does so using Outlines, so this is
enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a
high-quality one.
vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.
To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.
### Automatic Function Calling
To enable this feature, you should set the following flags:
* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it
* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it
deems appropriate.
* `--tool-call-parser` -- select the tool parser to use (listed below). Additional tool parsers
* `--tool-call-parser` -- select the tool parser to use (listed below). Additional tool parsers
will continue to be added in the future, and also can register your own tool parsers in the `--tool-parser-plugin`.
* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user defined tool parsers into vllm, the registered tool parser name can be specified in `--tool-call-parser`.
* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages
that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages
that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates)
from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json)
If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template!
If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template!
#### Hermes Models (`hermes`)
@ -256,8 +256,8 @@ All Nous Research Hermes-series models newer than Hermes 2 Pro should be support
* `NousResearch/Hermes-3-*`
_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge
step in their creation_.
_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge
step in their creation_.
Flags: `--tool-call-parser hermes`
@ -269,9 +269,9 @@ Supported models:
* Additional mistral function-calling models are compatible as well.
Known issues:
1. Mistral 7B struggles to generate parallel tool calls correctly.
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
much shorter than what vLLM generates. Since an exception is thrown when this condition
1. Mistral 7B struggles to generate parallel tool calls correctly.
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided:
* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that
@ -291,11 +291,11 @@ Supported models:
* `meta-llama/Meta-Llama-3.1-405B-Instruct`
* `meta-llama/Meta-Llama-3.1-405B-Instruct-FP8`
The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling).
The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) in Llama-3.2 models, see the `pythonic` tool parser below.
Other tool calling formats like the built in python tool calling or custom tool calling are not supported.
Known issues:
1. Parallel tool calls are not supported.
1. Parallel tool calls are not supported.
2. The model can generate parameters with a wrong format, such as generating
an array serialized as string instead of an array.
@ -341,6 +341,34 @@ AI21's Jamba-1.5 models are supported.
Flags: `--tool-call-parser jamba`
#### Models with Pythonic Tool Calls (`pythonic`)
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
As a concrete example, these models may look up the weather in San Francisco and Seattle by generating:
```python
[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')]
```
Limitations:
* The model must not generate both text and tool calls in the same generation. This may not be hard to change for a specific model, but the community currently lacks consensus on which tokens to emit when starting and ending tool calls. (In particular, the Llama 3.2 models emit no such tokens.)
* Llama's smaller models struggle to use tools effectively.
Example supported models:
* `meta-llama/Llama-3.2-1B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`)
* `meta-llama/Llama-3.2-3B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`)
* `Team-ACE/ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`)
* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`)
Flags: `--tool-call-parser pythonic --chat-template {see_above}`
---
**WARNING**
Llama's smaller models frequently fail to emit tool calls in the correct format. Your mileage may vary.
---
### How to write a tool parser plugin
A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.

View File

@ -0,0 +1,98 @@
{{- bos_token }}
{%- if custom_tools is defined %}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
{%- set tools_in_user_message = false %}
{%- endif %}
{%- if not date_string is defined %}
{%- if strftime_now is defined %}
{%- set date_string = strftime_now("%d %b %Y") %}
{%- else %}
{%- set date_string = "26 Jul 2024" %}
{%- endif %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %}
{%- endif %}
{#- System message #}
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
{%- if tools is not none %}
{{- "Environment: ipython\n" }}
{%- endif %}
{{- "Cutting Knowledge Date: December 2023\n" }}
{{- "Today Date: " + date_string + "\n\n" }}
{%- if tools is not none and not tools_in_user_message %}
{{- "You have access to the following functions. To call functions, please respond with a python list of the calls. " }}
{{- 'Respond in the format [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] ' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{%- endif %}
{{- system_message }}
{{- "<|eot_id|>" }}
{#- Custom tools are passed in a user message with some extra guidance #}
{%- if tools_in_user_message and not tools is none %}
{#- Extract the first user message so we can plug it in here #}
{%- if messages | length != 0 %}
{%- set first_user_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
{{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
{{- "Given the following functions, please respond with a python list for function calls " }}
{{- "with their proper arguments to best answer the given prompt.\n\n" }}
{{- 'Respond in the format [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] ' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{{- first_user_message + "<|eot_id|>"}}
{%- endif %}
{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
{%- elif 'tool_calls' in message %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n[' -}}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- tool_call.name + '(' -}}
{%- for param in tool_call.arguments %}
{{- param + '=' -}}
{{- "%sr" | format(tool_call.arguments[param]) -}}
{% if not loop.last %}, {% endif %}
{%- endfor %}
{{- ')' -}}
{% if not loop.last %}, {% endif %}
{%- endfor %}
{{- ']<|eot_id|>' -}}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
{%- if message.content is mapping %}
{{- message.content | tojson }}
{%- else %}
{{- { "output": message.content } | tojson }}
{%- endif %}
{{- "<|eot_id|>" }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}

View File

@ -0,0 +1,65 @@
{{- bos_token }}
{%- if custom_tools is defined %}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language." %}
{%- endif %}
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
{%- if tools is not none and not tools_in_user_message %}
{{- "You are an expert in composing functions. You are given a question and a set of possible functions. Based on the question, you will need to make one or more function/tool calls to achieve the purpose.\n" }}
{{- "If none of the function can be used, point it out. If the given question lacks the parameters required by the function, also point it out.\n" }}
{{- "You should only return the function call in tools call sections.\n\n" }}
{{- "If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]\n" }}
{{- "You SHOULD NOT include any other text in the response.\n" }}
{{- "Here is a list of functions in JSON format that you can invoke.\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{{- "\n" }}
{%- endif %}
{{- system_message }}
{{- "<|eot_id|>" }}
{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
{%- elif 'tool_calls' in message %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n[' -}}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- tool_call.name + '(' -}}
{%- for param in tool_call.arguments %}
{{- param + '=' -}}
{{- "%sr" | format(tool_call.arguments[param]) -}}
{% if not loop.last %}, {% endif %}
{%- endfor %}
{{- ')' -}}
{% if not loop.last %}, {% endif %}
{%- endfor %}
{{- ']<|eot_id|>' -}}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
{%- if message.content is mapping %}
{{- message.content | tojson }}
{%- else %}
{{- { "output": message.content } | tojson }}
{%- endif %}
{{- "<|eot_id|>" }}
{%- endif %}
{%- endfor %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}

View File

@ -0,0 +1,160 @@
from typing import List
from unittest.mock import MagicMock
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction, run_tool_extraction_streaming)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "San Francisco", "metric": "celsius"}',
)
MORE_TYPES_FUNCTION_OUTPUT = (
"register_user(name='John Doe', "
"age=37, "
"address={'city': 'San Francisco', 'state': 'CA'}, "
"role=None, "
"passed_test=True, "
"aliases=['John', 'Johnny'])")
MORE_TYPES_FUNCTION_CALL = FunctionCall(
name="register_user",
arguments='{"name": "John Doe", '
'"age": 37, '
'"address": {"city": "San Francisco", "state": "CA"}, '
'"role": null, '
'"passed_test": true, '
'"aliases": ["John", "Johnny"]}',
)
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{}',
)
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"additional_data": {}}',
)
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"steps": []}',
)
ESCAPED_STRING_FUNCTION_OUTPUT = (
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')")
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)
model_output = "How can I help you today?"
content, tool_calls = run_tool_extraction(tool_parser,
model_output,
streaming=streaming)
assert content == model_output
assert len(tool_calls) == 0
TEST_CASES = [
pytest.param(True,
f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
id="simple_streaming"),
pytest.param(False,
f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
id="simple_nonstreaming"),
pytest.param(True,
f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
id="more_types_streaming"),
pytest.param(False,
f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
id="more_types_nonstreaming"),
pytest.param(True,
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_streaming"),
pytest.param(False,
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_nonstreaming"),
pytest.param(True,
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_streaming"),
pytest.param(False,
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_nonstreaming"),
pytest.param(True,
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
id="empty_list_streaming"),
pytest.param(False,
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
id="empty_list_nonstreaming"),
pytest.param(True,
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_streaming"),
pytest.param(False,
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_nonstreaming"),
pytest.param(True,
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_streaming"),
pytest.param(False,
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_nonstreaming"),
]
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
TEST_CASES)
def test_tool_call(streaming: bool, model_output: str,
expected_tool_calls: List[FunctionCall]):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)
content, tool_calls = run_tool_extraction(tool_parser,
model_output,
streaming=streaming)
assert content is None
assert len(tool_calls) == len(expected_tool_calls)
for actual, expected in zip(tool_calls, expected_tool_calls):
assert actual.type == "function"
assert actual.function == expected
def test_streaming_tool_call_with_large_steps():
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)
model_output_deltas = [
"[get_weather(city='San",
" Francisco', metric='celsius'), "
f"{PARAMETERLESS_FUNCTION_OUTPUT}, "
f"{EMPTY_LIST_FUNCTION_OUTPUT}]",
]
reconstructor = run_tool_extraction_streaming(
tool_parser, model_output_deltas, assert_one_tool_per_delta=False)
assert reconstructor.other_content == ""
assert len(reconstructor.tool_calls) == 3
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL

View File

@ -0,0 +1,123 @@
from typing import Iterable, List, Tuple, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers import ToolParser
class StreamingToolReconstructor:
def __init__(self, assert_one_tool_per_delta: bool = True):
self.tool_calls: List[ToolCall] = []
self.other_content: str = ""
self._assert_one_tool_per_delta = assert_one_tool_per_delta
def append_delta(self, delta: DeltaMessage):
if delta.content is not None:
self.other_content += delta.content
else:
assert delta.tool_calls, (
"Streaming results should have either content or tool calls "
"(or both)")
if self._assert_one_tool_per_delta:
# Note: This isn't strictly required by the API and may not be
# possible to adhere to depending on the token space and number of
# tokens per streamed response from the model, but it is required
# by tool_use tests, so we enforce it here by default also.
assert len(delta.tool_calls) < 2, (
"Streaming should include only one tool call per update.")
for call_delta in delta.tool_calls:
assert call_delta.type == "function", (
"Streaming tool calls should only emit function calls. Got "
f"{call_delta.type}")
current_tool_call = self.tool_calls[
call_delta.index] if call_delta.index < len(
self.tool_calls) else None
if current_tool_call:
assert (not call_delta.function.name), (
"Streaming tool calls should emit the full function name "
f"exactly once. Got {call_delta.function.name}")
assert (not call_delta.id), (
"Streaming tool calls must emit function id only once. Got "
f"{call_delta.id}")
assert (call_delta.index == len(self.tool_calls) - 1), (
f"Incorrect index for tool delta. Got {call_delta.index}, "
f"expected {len(self.tool_calls) - 1}")
current_tool_call.function.arguments += (
call_delta.function.arguments)
else:
assert call_delta.id is not None, (
"Streaming tool calls must have an id on first appearance")
assert call_delta.function.name is not None, (
"Streaming tool calls must have a function name on first "
"appearance")
assert call_delta.index == len(self.tool_calls), (
f"Incorrect index for tool delta. Got {call_delta.index}, "
f"expected {len(self.tool_calls)}")
self.tool_calls.append(
ToolCall(id=call_delta.id,
function=FunctionCall(
name=call_delta.function.name,
arguments=call_delta.function.arguments
or "")))
def run_tool_extraction(
tool_parser: ToolParser,
model_output: str,
request: Union[ChatCompletionRequest, None] = None,
streaming: bool = False,
assert_one_tool_per_delta: bool = True,
) -> Tuple[Union[str, None], List[ToolCall]]:
if streaming:
reconstructor = run_tool_extraction_streaming(
tool_parser,
model_output,
request,
assert_one_tool_per_delta=assert_one_tool_per_delta)
return reconstructor.other_content or None, reconstructor.tool_calls
else:
extracted = run_tool_extraction_nonstreaming(tool_parser, model_output,
request)
assert extracted.tools_called == bool(extracted.tool_calls)
return extracted.content, extracted.tool_calls
def run_tool_extraction_nonstreaming(
tool_parser: ToolParser,
model_output: str,
request: Union[ChatCompletionRequest, None] = None
) -> ExtractedToolCallInformation:
request = request or ChatCompletionRequest(messages=[], model="test-model")
return tool_parser.extract_tool_calls(model_output, request)
def run_tool_extraction_streaming(
tool_parser: ToolParser,
model_deltas: Iterable[str],
request: Union[ChatCompletionRequest, None] = None,
assert_one_tool_per_delta: bool = True,
) -> StreamingToolReconstructor:
request = request or ChatCompletionRequest(messages=[], model="test-model")
reconstructor = StreamingToolReconstructor(
assert_one_tool_per_delta=assert_one_tool_per_delta)
previous_text = ""
previous_tokens: List[int] = []
for delta in model_deltas:
token_delta = [
tool_parser.vocab.get(token)
for token in tool_parser.model_tokenizer.tokenize(delta)
if token in tool_parser.vocab
]
current_text = previous_text + delta
current_tokens = previous_tokens + token_delta
delta_message = tool_parser.extract_tool_calls_streaming(
previous_text, current_text, delta, previous_tokens,
current_tokens, token_delta, request)
if delta_message is not None:
reconstructor.append_delta(delta_message)
previous_text = current_text
previous_tokens = current_tokens
return reconstructor

View File

@ -122,7 +122,17 @@ CONFIGS: Dict[str, ServerConfig] = {
],
"supports_parallel":
False,
}
},
"toolACE": {
"model":
"Team-ACE/ToolACE-8B",
"arguments": [
"--tool-call-parser", "pythonic", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_toolace.jinja")
],
"supports_parallel":
True,
},
}
WEATHER_TOOL: ChatCompletionToolParam = {

View File

@ -74,6 +74,11 @@ class OpenAIServingChat(OpenAIServing):
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools:
try:
if (tool_parser == "pythonic" and
model_config.model.startswith("meta-llama/Llama-3.2")):
logger.warning(
"Llama3.2 models may struggle to emit valid pythonic"
" tool calls")
self.tool_parser = ToolParserManager.get_tool_parser(
tool_parser)
except Exception as e:

View File

@ -6,9 +6,11 @@ from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser
from .pythonic_tool_parser import PythonicToolParser
__all__ = [
"ToolParser", "ToolParserManager", "Granite20bFCToolParser",
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser"
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
"PythonicToolParser"
]

View File

@ -0,0 +1,289 @@
import ast
import json
import re
from typing import Any, Sequence, Tuple, Union
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
@ToolParserManager.register_module("pythonic")
class PythonicToolParser(ToolParser):
"""
Tool call parser for models that produce tool calls in a pythonic style,
such as Llama 3.2 models.
Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set
"""
# TODO(mdepinet): Possible future improvements:
# 1. Support text + tools separated by either <|python_tag|> or \n\n
# 2. Support tools outside of a list (or separated by a semicolon).
# This depends on item 1 for consistent streaming.
# Neither of these are necessary for e.g. ToolACE, but both would help make
# Llama3.2 models more reliable.
TOOL_CALL_REGEX = re.compile(
r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
re.DOTALL)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# Rename for readability. This is NOT a tool id.
@property
def current_tool_index(self) -> int:
return self.current_tool_id
@current_tool_index.setter
def current_tool_index(self, value: int) -> None:
self.current_tool_id = value
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
if not (self.TOOL_CALL_REGEX.match(model_output)):
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
module = ast.parse(model_output)
parsed = getattr(module.body[0], "value", None)
if isinstance(parsed, ast.List) and all(
isinstance(e, ast.Call) for e in parsed.elts):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls")
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if not current_text.startswith("["):
return DeltaMessage(content=delta_text)
try:
valid_and_added_text = _make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
module = ast.parse(valid_text)
parsed = getattr(module.body[0], "value", None)
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts):
raise _UnexpectedAstError(
"Tool output must be a list of function calls")
tool_calls = [
_handle_single_tool(e) # type: ignore
for e in parsed.elts
]
tool_deltas = []
for index, new_call in enumerate(tool_calls):
if index < self.current_tool_index:
continue
self.current_tool_index = index
if len(self.streamed_args_for_tool) == index:
self.streamed_args_for_tool.append("")
new_call_complete = index < len(
tool_calls) - 1 or ")]" not in added_text
if new_call_complete:
self.current_tool_index += 1
withheld_suffix = (added_text[:-2]
if not new_call_complete else "")
if not new_call_complete and added_text[-2] == ")":
# Function call is incomplete. Withhold the closing bracket.
withheld_suffix = withheld_suffix + "}"
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(self.streamed_args_for_tool[index],
new_call, index, withheld_suffix)
if delta is not None:
tool_deltas.append(delta)
if (delta.function is not None
and delta.function.arguments is not None):
self.streamed_args_for_tool[
index] += delta.function.arguments
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining it's final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if tool_deltas and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if tool_deltas:
return DeltaMessage(tool_calls=tool_deltas)
elif not added_text and self.current_tool_id > 0:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return DeltaMessage(content='')
else:
return None
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError(
"Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(type="function",
function=FunctionCall(name=function_name,
arguments=json.dumps(arguments)))
def _make_valid_python(text: str) -> Union[Tuple[str, str], None]:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[:text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[:text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
"[") and not text.endswith(")"):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
index: int,
withheld_suffix: str) -> Union[DeltaToolCall, None]:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[:-len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(id=new_call.id,
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
))
arg_diff = new_call_args[len(previously_sent_args):]
return DeltaToolCall(
id="", index=index, function=DeltaFunctionCall(
arguments=arg_diff)) if arg_diff else None