[Feature] OpenAI-Compatible Tools API + Streaming for Hermes & Mistral models (#5649)

Co-authored-by: constellate <constellate@1-ai-appserver-staging.codereach.com>
Co-authored-by: Kyle Mistele <kyle@constellate.ai>
Kyle Mistele 2024-09-04 15:18:13 -05:00 committed by GitHub
parent 561d6f8077
commit e02ce498be
26 changed files with 2591 additions and 86 deletions

View File

@ -92,6 +92,7 @@ steps:
- pytest -v -s entrypoints/openai
- pytest -v -s entrypoints/test_chat_utils.py
- label: Distributed Tests (4 GPUs) # 10min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
@ -271,6 +272,15 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- bash ./run-tests.sh -c configs/models-small.txt -t 1
- label: OpenAI-Compatible Tool Use # 20 min
fast_check: false
mirror_hardwares: [ amd ]
source_file_dependencies:
- vllm/
- tests/tool_use
commands:
- pytest -v -s tool_use
##### 1 GPU test #####
##### multi gpus test #####

View File

@ -110,6 +110,14 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
:func: create_parser_for_docs
:prog: vllm serve
```
## Tool Calling in the Chat Completion API
### Named Function Calling
By default, vLLM supports only named function calling in the chat completion API. This is implemented with Outlines, so it is
enabled out of the box and works with any supported model. You are guaranteed a validly-parsable function call, but not
necessarily a high-quality one.
To use a named function, define the functions in the `tools` parameter of the chat completion request, and
specify the `name` of one of those tools in the `tool_choice` parameter.
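For example, a request with a named tool might look like the following. This is a minimal sketch using the `openai` Python client; the model name and server address are placeholders for whatever you are serving:
```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "state": {"type": "string"}
            },
            "required": ["city", "state"]
        }
    }
}]

completion = client.chat.completions.create(
    model="NousResearch/Hermes-2-Pro-Llama-3-8B",  # placeholder model
    messages=[{
        "role": "user",
        "content": "What is the weather in Dallas, TX?"
    }],
    tools=tools,
    # naming one of the tools forces a call to it; guided decoding constrains
    # the output to that tool's JSON schema
    tool_choice={
        "type": "function",
        "function": {"name": "get_current_weather"}
    },
)
print(completion.choices[0].message.tool_calls)
```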
### Config file
@ -140,10 +148,52 @@ The order of priorities is `command line > config file values > defaults`.
## Tool calling in the chat completion API
vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap.
To use a named function, you need to define the function in the `tools` parameter and select it by name in the `tool_choice` parameter.
It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt. **This may change in the future.**
It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt.
vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
Please refer to the OpenAI API reference documentation for more information.
### Automatic Function Calling
To enable this feature, you should set the following flags:
* `--enable-auto-tool-choice` -- **mandatory** for auto tool choice. Tells vLLM that you want to enable the model to generate its own tool calls when it
deems appropriate.
* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers
will continue to be added in the future.
* `--chat-template` -- **optional** for auto tool choice. The path to a chat template that handles `tool`-role messages and `assistant`-role messages
that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
template configured in its `tokenizer_config.json`; in that case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates)
from HuggingFace, and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json).
If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template!
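For example, once the server is started with these flags, passing `tools` together with `tool_choice="auto"` lets the model decide whether to call a tool. This is a minimal sketch; the model name, server address, and tool definition are placeholders, and the client example added in this PR shows the same flow with streaming and tool results:
```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# placeholder tool definition in the standard OpenAI format
tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "state": {"type": "string"}
            },
            "required": ["city", "state"]
        }
    }
}]

completion = client.chat.completions.create(
    model="NousResearch/Hermes-2-Pro-Llama-3-8B",  # placeholder model
    messages=[{
        "role": "user",
        "content": "What is the weather in Dallas, TX?"
    }],
    tools=tools,
    tool_choice="auto",  # the model decides whether to call a tool
)

choice = completion.choices[0]
if choice.finish_reason == "tool_calls":
    for call in choice.message.tool_calls:
        print(call.function.name, call.function.arguments)
else:
    print(choice.message.content)
```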
#### Hermes Models
All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported.
* `NousResearch/Hermes-2-Pro-*`
* `NousResearch/Hermes-2-Theta-*`
* `NousResearch/Hermes-3-*`
_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge
step in their creation_.
Flags: `--tool-call-parser hermes`
#### Mistral Models
Supported models:
* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed)
* Additional Mistral function-calling models are compatible as well.
Known issues:
1. Mistral 7B struggles to generate parallel tool calls correctly.
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided:
* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that
it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits, as illustrated below)
* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt
when tools are provided, which results in much better reliability when working with parallel tool calling.
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
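As a short illustration of the ID handling described above (not part of either template), the provided templates keep only the last 9 characters of vLLM's generated tool call IDs:
```python
# example vLLM-style tool call ID; this value is taken from the tests in this PR
vllm_tool_call_id = "chatcmpl-tool-03e6481b146e408e9523d9c956696295"

# the provided Mistral chat templates truncate the ID to its last 9 characters
# so that it satisfies the official template's 9-character requirement
mistral_compatible_id = vllm_tool_call_id[-9:]
print(mistral_compatible_id)  # 956696295
```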

View File

@ -0,0 +1,162 @@
"""
Set up this example by starting a vLLM OpenAI-compatible server with tool call
options enabled. For example:
IMPORTANT: for Mistral, you must use one of the provided Mistral tool call
templates, or your own - the model's default template doesn't work for tool calls with vLLM.
See the vLLM docs on OpenAI server & tool calling for more details.
vllm serve --model mistralai/Mistral-7B-Instruct-v0.3 \
--chat-template examples/tool_chat_template_mistral.jinja \
--enable-auto-tool-choice --tool-call-parser mistral
OR
vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \
--chat-template examples/tool_chat_template_hermes.jinja \
--enable-auto-tool-choice --tool-call-parser hermes
"""
import json
from openai import OpenAI
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
tools = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["city", "state", "unit"]
}
}
}]
messages = [{
"role": "user",
"content": "Hi! How are you doing today?"
}, {
"role": "assistant",
"content": "I'm doing well! How can I help you?"
}, {
    "role":
    "user",
    "content":
    "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
}]
chat_completion = client.chat.completions.create(messages=messages,
model=model,
tools=tools)
print("Chat completion results:")
print(chat_completion)
print("\n\n")
tool_calls_stream = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
stream=True)
chunks = []
for chunk in tool_calls_stream:
chunks.append(chunk)
if chunk.choices[0].delta.tool_calls:
print(chunk.choices[0].delta.tool_calls[0])
else:
print(chunk.choices[0].delta)
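# Reassemble the streamed tool calls: collect the argument fragments for each
# tool call index so the complete arguments string can be printed per call.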
arguments = []
tool_call_idx = -1
for chunk in chunks:
if chunk.choices[0].delta.tool_calls:
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != tool_call_idx:
if tool_call_idx >= 0:
print(
f"streamed tool call arguments: {arguments[tool_call_idx]}"
)
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
arguments.append("")
if tool_call.id:
print(f"streamed tool call id: {tool_call.id} ")
if tool_call.function:
if tool_call.function.name:
print(f"streamed tool call name: {tool_call.function.name}")
if tool_call.function.arguments:
arguments[tool_call_idx] += tool_call.function.arguments
if len(arguments):
print(f"streamed tool call arguments: {arguments[-1]}")
print("\n\n")
messages.append({
"role": "assistant",
"tool_calls": chat_completion.choices[0].message.tool_calls
})
# Now, simulate a tool call
def get_current_weather(city: str, state: str, unit: str):
    return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
            "partly cloudy, with highs in the 90's.")
available_tools = {"get_current_weather": get_current_weather}
completion_tool_calls = chat_completion.choices[0].message.tool_calls
for call in completion_tool_calls:
tool_to_call = available_tools[call.function.name]
args = json.loads(call.function.arguments)
result = tool_to_call(**args)
print(result)
messages.append({
"role": "tool",
"content": result,
"tool_call_id": call.id,
"name": call.function.name
})
chat_completion_2 = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
stream=False)
print("\n\n")
print(chat_completion_2)

View File

@ -0,0 +1,129 @@
{%- macro json_to_python_type(json_spec) %}
{%- set basic_type_map = {
"string": "str",
"number": "float",
"integer": "int",
"boolean": "bool"
} %}
{%- if basic_type_map[json_spec.type] is defined %}
{{- basic_type_map[json_spec.type] }}
{%- elif json_spec.type == "array" %}
{{- "list[" + json_to_python_type(json_spec|items) + "]" }}
{%- elif json_spec.type == "object" %}
{%- if json_spec.additionalProperties is defined %}
{{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }}
{%- else %}
{{- "dict" }}
{%- endif %}
{%- elif json_spec.type is iterable %}
{{- "Union[" }}
{%- for t in json_spec.type %}
{{- json_to_python_type({"type": t}) }}
{%- if not loop.last %}
{{- "," }}
{%- endif %}
{%- endfor %}
{{- "]" }}
{%- else %}
{{- "Any" }}
{%- endif %}
{%- endmacro %}
{{- bos_token }}
{{- "<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> " }}
{%- if tools is iterable and tools | length > 0 %}
{%- for tool in tools %}
{%- if tool.function is defined %}
{%- set tool = tool.function %}
{%- endif %}
{{- '{"type": "function", "function": ' }}
{{- '{"name": "' + tool.name + '", ' }}
{{- '"description": "' + tool.name + '(' }}
{%- for param_name, param_fields in tool.parameters.properties|items %}
{{- param_name + ": " + json_to_python_type(param_fields) }}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- ")" }}
{%- if tool.return is defined %}
{{- " -> " + json_to_python_type(tool.return) }}
{%- endif %}
{{- " - " + tool.description + "\n\n" }}
{%- for param_name, param_fields in tool.parameters.properties|items %}
{%- if loop.first %}
{{- " Args:\n" }}
{%- endif %}
{{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}
{%- endfor %}
{%- if tool.return is defined and tool.return.description is defined %}
{{- "\n Returns:\n " + tool.return.description }}
{%- endif %}
{{- '"' }}
{{- ', "parameters": ' }}
{%- if tool.parameters.properties | length == 0 %}
{{- "{}" }}
{%- else %}
{{- tool.parameters|tojson }}
{%- endif %}
{{- "}" }}
{%- if not loop.last %}
{{- "\n" }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- " </tools>" }}
{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}}
' }}
{{- "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
" }}
{{- "<tool_call>
" }}
{{- '{"name": <function-name>, "arguments": <args-dict>}
' }}
{{- '</tool_call><|im_end|>' }}
{%- for message in messages %}
{%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" and message.tool_calls is defined %}
{{- '<|im_start|>' + message.role }}
{%- for tool_call in message.tool_calls %}
{{- '\n<tool_call>\n' }}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '{' }}
{{- '"name": "' }}
{{- tool_call.name }}
{{- '"}' }}
{{- ', ' }}
{%- if tool_call.arguments is defined %}
{{- '"arguments": ' }}
{{- tool_call.arguments|tojson }}
{%- endif %}
{{- '\n</tool_call>' }}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.previtem and loop.previtem.role != "tool" %}
{{- '<|im_start|>tool\n' }}
{%- endif %}
{{- '<tool_response>\n' }}
{{- message.content }}
{%- if not loop.last %}
{{- '\n</tool_response>\n' }}
{%- else %}
{{- '\n</tool_response>' }}
{%- endif %}
{%- if not loop.last and loop.nextitem.role != "tool" %}
{{- '<|im_end|>' }}
{%- elif loop.last %}
{{- '<|im_end|>' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}

View File

@ -0,0 +1,86 @@
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif %}
{%- endfor %}
{{- bos_token }}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{%- if tools is not none and (message == user_messages[-1]) %}
{{- "[AVAILABLE_TOOLS] [" }}
{%- for tool in tools %}
{%- set tool = tool.function %}
{{- '{"type": "function", "function": {' }}
{%- for key, val in tool.items() if key != "return" %}
{%- if val is string %}
{{- '"' + key + '": "' + val + '"' }}
{%- else %}
{{- '"' + key + '": ' + val|tojson }}
{%- endif %}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- "}}" }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "[/AVAILABLE_TOOLS]" }}
{%- endif %}
{%- if loop.last and system_message is defined %}
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
{%- else %}
{{- "[INST] " + message["content"] + "[/INST]" }}
{%- endif %}
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
{%- if message.tool_calls is defined %}
{%- set tool_calls = message.tool_calls %}
{%- else %}
{%- set tool_calls = message.content %}
{%- endif %}
{{- "[TOOL_CALLS] [" }}
{%- for tool_call in tool_calls %}
{%- set out = tool_call.function|tojson %}
{{- out[:-1] }}
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
{%- endif %}
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" + eos_token }}
{%- endif %}
{%- endfor %}
{%- elif message["role"] == "assistant" %}
{{- " " + message["content"] + eos_token }}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
{%- endif %}
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}

View File

@ -0,0 +1,94 @@
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{%- if tools is defined %}
{%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array of objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %}
{%- if system_message is defined %}
{%- set system_message = parallel_tool_prompt + "\n\n" + system_message %}
{%- else %}
{%- set system_message = parallel_tool_prompt %}
{%- endif %}
{%- endif %}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif %}
{%- endfor %}
{{- bos_token }}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{%- if tools is not none and (message == user_messages[-1]) %}
{{- "[AVAILABLE_TOOLS] [" }}
{%- for tool in tools %}
{%- set tool = tool.function %}
{{- '{"type": "function", "function": {' }}
{%- for key, val in tool.items() if key != "return" %}
{%- if val is string %}
{{- '"' + key + '": "' + val + '"' }}
{%- else %}
{{- '"' + key + '": ' + val|tojson }}
{%- endif %}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- "}}" }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "[/AVAILABLE_TOOLS]" }}
{%- endif %}
{%- if loop.last and system_message is defined %}
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
{%- else %}
{{- "[INST] " + message["content"] + "[/INST]" }}
{%- endif %}
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
{%- if message.tool_calls is defined %}
{%- set tool_calls = message.tool_calls %}
{%- else %}
{%- set tool_calls = message.content %}
{%- endif %}
{{- "[TOOL_CALLS] [" }}
{%- for tool_call in tool_calls %}
{%- set out = tool_call.function|tojson %}
{{- out[:-1] }}
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
{%- endif %}
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" + eos_token }}
{%- endif %}
{%- endfor %}
{%- elif message["role"] == "assistant" %}
{{- " " + message["content"] + eos_token }}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
{%- endif %}
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}

View File

@ -20,6 +20,7 @@ lm-format-enforcer == 0.10.6
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
partial-json-parser # used for parsing partial JSON outputs
pyzmq
msgspec
gguf == 0.9.1

View File

View File

@ -0,0 +1,32 @@
import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download
from tests.utils import RemoteOpenAIServer
from .utils import ARGS, CONFIGS, ServerConfig
# for each server config, download the model and return the config
@pytest.fixture(scope="session", params=CONFIGS.keys())
def server_config(request):
config = CONFIGS[request.param]
# download model and tokenizer using transformers
snapshot_download(config["model"])
yield CONFIGS[request.param]
# run this for each server config
@pytest.fixture(scope="session")
def server(request, server_config: ServerConfig):
model = server_config["model"]
args_for_model = server_config["arguments"]
with RemoteOpenAIServer(model, ARGS + args_for_model,
max_wait_seconds=480) as server:
yield server
@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
async with server.get_async_client() as async_client:
yield async_client

View File

@ -0,0 +1,143 @@
from typing import List
import openai
import pytest
from .utils import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL
# test: make sure chat completions without tools provided work even when tools
# are enabled. This makes sure tool call chat templates work, AND that the tool
# parser stream processing doesn't change the output of the model.
@pytest.mark.asyncio
async def test_chat_completion_without_tools(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
temperature=0,
max_tokens=150,
model=model_name,
logprobs=False)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
output_text = chat_completion.choices[0].message.content
# check to make sure we got text
assert output_text is not None
assert len(output_text) > 0
assert stop_reason != "tool_calls"
# check to make sure no tool calls were returned
assert (choice.message.tool_calls is None
or len(choice.message.tool_calls) == 0)
# make the same request, streaming
stream = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
temperature=0,
max_tokens=150,
model=model_name,
logprobs=False,
stream=True,
)
chunks: List[str] = []
finish_reason_count = 0
role_sent: bool = False
# assemble streamed chunks
async for chunk in stream:
delta = chunk.choices[0].delta
# make sure the role is assistant
if delta.role:
assert not role_sent
assert delta.role == 'assistant'
role_sent = True
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == choice.finish_reason
# make sure tool call chunks aren't being streamed
assert not delta.tool_calls or len(delta.tool_calls) == 0
# make sure the role was sent, only 1 finish reason was sent, that chunks
# were in fact sent, and that the chunks match non-streaming
assert role_sent
assert finish_reason_count == 1
assert len(chunks)
assert "".join(chunks) == output_text
# test: conversation with tools enabled and provided that should not invoke
# tools, to make sure we can still get normal chat completion responses
# and that they won't be parsed as tools
@pytest.mark.asyncio
async def test_chat_completion_with_tools(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
temperature=0,
max_tokens=150,
model=model_name,
tools=[WEATHER_TOOL],
logprobs=False)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
output_text = chat_completion.choices[0].message.content
# check to make sure we got text
assert output_text is not None
assert stop_reason != 'tool_calls'
assert len(output_text) > 0
# check to make sure no tool calls were returned
assert (choice.message.tool_calls is None
or len(choice.message.tool_calls) == 0)
# make the same request, streaming
stream = await client.chat.completions.create(
messages=MESSAGES_WITHOUT_TOOLS,
temperature=0,
max_tokens=150,
model=model_name,
logprobs=False,
tools=[WEATHER_TOOL],
stream=True,
)
chunks: List[str] = []
finish_reason_count = 0
role_sent: bool = False
# assemble streamed chunks
async for chunk in stream:
delta = chunk.choices[0].delta
# make sure the role is assistant
if delta.role:
assert delta.role == 'assistant'
role_sent = True
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# make sure tool call chunks aren't being streamed
assert not delta.tool_calls or len(delta.tool_calls) == 0
# make sure the role was sent, only 1 finish reason was sent, that chunks
# were in fact sent, and that the chunks match non-streaming
assert role_sent
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == stop_reason
assert chunk.choices[0].finish_reason != 'tool_calls'
assert len(chunks)
assert "".join(chunks) == output_text

View File

@ -0,0 +1,193 @@
import json
from typing import Dict, List, Optional
import openai
import pytest
from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
WEATHER_TOOL)
# test: getting the model to generate parallel tool calls (streaming/not)
# when requested. NOTE that not all models may support this, so some exclusions
# may be added in the future. e.g. llama 3.1 models are not designed to support
# parallel tool calls.
@pytest.mark.asyncio
async def test_parallel_tool_calls(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
temperature=0,
max_tokens=200,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
# make sure 2 tool calls are present
assert choice.message.role == "assistant"
assert non_streamed_tool_calls is not None
assert len(non_streamed_tool_calls) == 2
for tool_call in non_streamed_tool_calls:
# make sure the tool includes a function and ID
assert tool_call.type == "function"
assert tool_call.function is not None
assert isinstance(tool_call.id, str)
assert len(tool_call.id) > 16
# make sure the weather tool was called correctly
assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
assert isinstance(tool_call.function.arguments, str)
parsed_arguments = json.loads(tool_call.function.arguments)
assert isinstance(parsed_arguments, Dict)
assert isinstance(parsed_arguments.get("city"), str)
assert isinstance(parsed_arguments.get("state"), str)
assert stop_reason == "tool_calls"
# make the same request, streaming
stream = await client.chat.completions.create(
model=model_name,
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
temperature=0,
max_tokens=200,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
role_name: Optional[str] = None
finish_reason_count: int = 0
tool_call_names: List[str] = []
tool_call_args: List[str] = []
tool_call_idx: int = -1
tool_call_id_count: int = 0
async for chunk in stream:
# if there's a finish reason make sure it's tools
if chunk.choices[0].finish_reason:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == 'tool_calls'
# if a role is being streamed make sure it wasn't already set to
# something else
if chunk.choices[0].delta.role:
assert not role_name or role_name == 'assistant'
role_name = 'assistant'
# if a tool call is streamed make sure there's exactly one
# (based on the request parameters)
streamed_tool_calls = chunk.choices[0].delta.tool_calls
if streamed_tool_calls and len(streamed_tool_calls) > 0:
# make sure only one diff is present - correct even for parallel
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]
# if a new tool is being called, set up empty arguments
if tool_call.index != tool_call_idx:
tool_call_idx = tool_call.index
tool_call_args.append("")
# if a tool call ID is streamed, make sure one hasn't been already
if tool_call.id:
tool_call_id_count += 1
assert (isinstance(tool_call.id, str)
and (len(tool_call.id) > 16))
# if parts of the function start being streamed
if tool_call.function:
# if the function name is defined, set it. it should be streamed
# IN ENTIRETY, exactly one time.
if tool_call.function.name:
assert isinstance(tool_call.function.name, str)
tool_call_names.append(tool_call.function.name)
if tool_call.function.arguments:
# make sure they're a string and then add them to the list
assert isinstance(tool_call.function.arguments, str)
tool_call_args[
tool_call.index] += tool_call.function.arguments
assert finish_reason_count == 1
assert role_name == 'assistant'
assert (len(non_streamed_tool_calls) == len(tool_call_names) ==
len(tool_call_args))
for i in range(2):
assert non_streamed_tool_calls[i].function.name == tool_call_names[i]
streamed_args = json.loads(tool_call_args[i])
non_streamed_args = json.loads(
non_streamed_tool_calls[i].function.arguments)
assert streamed_args == non_streamed_args
# test: providing parallel tool calls back to the model to get a response
# (streaming/not)
@pytest.mark.asyncio
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
temperature=0,
max_tokens=200,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
choice = chat_completion.choices[0]
assert choice.finish_reason != "tool_calls" # "stop" or "length"
assert choice.message.role == "assistant"
assert choice.message.tool_calls is None \
or len(choice.message.tool_calls) == 0
assert choice.message.content is not None
assert "98" in choice.message.content # Dallas temp in tool response
assert "78" in choice.message.content # Orlando temp in tool response
stream = await client.chat.completions.create(
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
temperature=0,
max_tokens=200,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
chunks: List[str] = []
finish_reason_count = 0
role_sent: bool = False
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert not role_sent
assert delta.role == "assistant"
role_sent = True
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == choice.finish_reason
assert not delta.tool_calls or len(delta.tool_calls) == 0
assert role_sent
assert finish_reason_count == 1
assert len(chunks)
assert "".join(chunks) == choice.message.content

View File

@ -0,0 +1,192 @@
import json
from typing import Dict, List, Optional
import openai
import pytest
from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE,
SEARCH_TOOL, WEATHER_TOOL)
# test: request a chat completion that should return tool calls, so we know they
# are parsable
@pytest.mark.asyncio
async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_TOOLS,
temperature=0,
max_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
tool_calls = chat_completion.choices[0].message.tool_calls
# make sure a tool call is present
assert choice.message.role == 'assistant'
assert tool_calls is not None
assert len(tool_calls) == 1
assert tool_calls[0].type == 'function'
assert tool_calls[0].function is not None
assert isinstance(tool_calls[0].id, str)
assert len(tool_calls[0].id) > 16
# make sure the weather tool was called (classic example) with arguments
assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
assert tool_calls[0].function.arguments is not None
assert isinstance(tool_calls[0].function.arguments, str)
# make sure the arguments parse properly
parsed_arguments = json.loads(tool_calls[0].function.arguments)
assert isinstance(parsed_arguments, Dict)
assert isinstance(parsed_arguments.get("city"), str)
assert isinstance(parsed_arguments.get("state"), str)
assert parsed_arguments.get("city") == "Dallas"
assert parsed_arguments.get("state") == "TX"
assert stop_reason == "tool_calls"
function_name: Optional[str] = None
function_args_str: str = ''
tool_call_id: Optional[str] = None
role_name: Optional[str] = None
finish_reason_count: int = 0
# make the same request, streaming
stream = await client.chat.completions.create(
model=model_name,
messages=MESSAGES_ASKING_FOR_TOOLS,
temperature=0,
max_tokens=100,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
async for chunk in stream:
assert chunk.choices[0].index == 0
if chunk.choices[0].finish_reason:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == 'tool_calls'
# if a role is being streamed make sure it wasn't already set to
# something else
if chunk.choices[0].delta.role:
assert not role_name or role_name == 'assistant'
role_name = 'assistant'
# if a tool call is streamed make sure there's exactly one
# (based on the request parameters)
streamed_tool_calls = chunk.choices[0].delta.tool_calls
if streamed_tool_calls and len(streamed_tool_calls) > 0:
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]
# if a tool call ID is streamed, make sure one hasn't been already
if tool_call.id:
assert not tool_call_id
tool_call_id = tool_call.id
# if parts of the function start being streamed
if tool_call.function:
# if the function name is defined, set it. it should be streamed
# IN ENTIRETY, exactly one time.
if tool_call.function.name:
assert function_name is None
assert isinstance(tool_call.function.name, str)
function_name = tool_call.function.name
if tool_call.function.arguments:
assert isinstance(tool_call.function.arguments, str)
function_args_str += tool_call.function.arguments
assert finish_reason_count == 1
assert role_name == 'assistant'
assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16)
# validate the name and arguments
assert function_name == WEATHER_TOOL["function"]["name"]
assert function_name == tool_calls[0].function.name
assert isinstance(function_args_str, str)
# validate arguments
streamed_args = json.loads(function_args_str)
assert isinstance(streamed_args, Dict)
assert isinstance(streamed_args.get("city"), str)
assert isinstance(streamed_args.get("state"), str)
assert streamed_args.get("city") == "Dallas"
assert streamed_args.get("state") == "TX"
# make sure everything matches non-streaming except for ID
assert function_name == tool_calls[0].function.name
assert choice.message.role == role_name
assert choice.message.tool_calls[0].function.name == function_name
# compare streamed with non-streamed args Dict-wise, not string-wise
# because character-to-character comparison might not work e.g. the tool
# call parser adding extra spaces or something like that. we care about the
# dicts matching not byte-wise match
assert parsed_arguments == streamed_args
# test: providing tools and results back to model to get a non-tool response
# (streaming/not)
@pytest.mark.asyncio
async def test_tool_call_with_results(client: openai.AsyncOpenAI):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITH_TOOL_RESPONSE,
temperature=0,
max_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
choice = chat_completion.choices[0]
assert choice.finish_reason != "tool_calls" # "stop" or "length"
assert choice.message.role == "assistant"
assert choice.message.tool_calls is None \
or len(choice.message.tool_calls) == 0
assert choice.message.content is not None
assert "98" in choice.message.content # the temperature from the response
stream = await client.chat.completions.create(
messages=MESSAGES_WITH_TOOL_RESPONSE,
temperature=0,
max_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
chunks: List[str] = []
finish_reason_count = 0
role_sent: bool = False
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert not role_sent
assert delta.role == "assistant"
role_sent = True
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == choice.finish_reason
assert not delta.tool_calls or len(delta.tool_calls) == 0
assert role_sent
assert finish_reason_count == 1
assert len(chunks)
assert "".join(chunks) == choice.message.content

tests/tool_use/utils.py (new file, 215 lines)
View File

@ -0,0 +1,215 @@
from typing import Dict, List
from openai.types.chat import (ChatCompletionMessageParam,
ChatCompletionToolParam)
from typing_extensions import TypedDict
from tests.utils import VLLM_PATH
class ServerConfig(TypedDict):
model: str
arguments: List[str]
# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"]
CONFIGS: Dict[str, ServerConfig] = {
"hermes": {
"model":
"NousResearch/Hermes-2-Pro-Llama-3-8B",
"arguments": [
"--tool-call-parser", "hermes", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja")
]
},
"mistral": {
"model":
"mistralai/Mistral-7B-Instruct-v0.3",
"arguments": [
"--tool-call-parser", "mistral", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"),
"--ignore-patterns=\"consolidated.safetensors\""
]
}
}
WEATHER_TOOL: ChatCompletionToolParam = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, "
"e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state "
"that the city is in, e.g. 'CA' which would "
"mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
}
}
}
}
SEARCH_TOOL: ChatCompletionToolParam = {
    "type": "function",
    "function": {
        "name":
        "web_search",
        "description":
        "Search the internet and get a summary of the top "
        "10 webpages. Should only be used if you don't know "
        "the answer to a user query, and the results are likely "
        "to be able to be found with a web search",
        "parameters": {
            "type": "object",
            "properties": {
                "search_term": {
                    "type":
                    "string",
                    "description":
                    "The term to use in the search. This should "
                    "ideally be keywords to search for, not a "
                    "natural-language question"
                }
            },
            "required": ["search_term"]
        }
    }
}
MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{
"role":
"system",
"content":
"You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
}, {
"role":
"user",
"content":
"Hi! How are you?"
}, {
"role":
"assistant",
"content":
"I'm doing great! How can I assist you?"
}, {
"role":
"user",
"content":
"Can you tell me a joke please?"
}]
MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas in Fahrenheit?"
}]
MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas in Fahrenheit?"
}, {
"role":
"assistant",
"tool_calls": [{
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"type": "function",
"function": {
"name":
WEATHER_TOOL["function"]["name"],
"arguments":
'{"city": "Dallas", "state": "TX", '
'"unit": "fahrenheit"}'
}
}]
}, {
    "role":
    "tool",
    "tool_call_id":
    "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
    "content":
    "The weather in Dallas is 98 degrees fahrenheit, with partly "
    "cloudy skies and a low chance of rain."
}]
MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas and Orlando, Florida in "
"Fahrenheit?"
}]
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas and Orlando, Florida in "
"Fahrenheit?"
}, {
"role":
"assistant",
"tool_calls": [{
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"type": "function",
"function": {
"name":
WEATHER_TOOL["function"]["name"],
"arguments":
'{"city": "Dallas", "state": "TX", '
'"unit": "fahrenheit"}'
}
}, {
"id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
"type": "function",
"function": {
"name":
WEATHER_TOOL["function"]["name"],
"arguments":
'{"city": "Orlando", "state": "Fl", '
'"unit": "fahrenheit"}'
}
}]
}, {
"role":
"tool",
"tool_call_id":
"chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"content":
"The weather in Dallas TX is 98 degrees fahrenheit with mostly "
"cloudy skies and a chance of rain in the evening."
}, {
    "role":
    "tool",
    "tool_call_id":
    "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
    "content":
    "The weather in Orlando FL is 78 degrees fahrenheit with clear "
    "skies."
}]

View File

@ -1,23 +1,28 @@
import asyncio
import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import lru_cache
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
Mapping, Optional, Tuple, TypeVar, Union)
Mapping, Optional, Tuple, TypeVar, Union, cast)
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import ChatCompletionContentPartImageParam
from openai.types.chat import (ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import ChatCompletionContentPartTextParam
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
ChatCompletionContentPartTextParam)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict, TypeAdapter
from pydantic import ConfigDict
from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
@ -54,7 +59,8 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam, ]
ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPartParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
@ -72,15 +78,33 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
same role.
"""
tool_call_id: Optional[str]
"""Tool call that this message is responding to."""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
"""The tool calls generated by the model, such as function calls."""
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam]
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict):
role: str
content: str
class ConversationMessage(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
content: Optional[str]
"""The contents of the message"""
tool_call_id: Optional[str]
"""Tool call that this message is responding to."""
name: Optional[str]
"""The name of the function to call"""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
"""The tool calls generated by the model, such as function calls."""
ModalityStr = Literal["image", "audio"]
@ -319,9 +343,11 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
return "\n".join(missing_placeholders + [text_prompt])
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam)
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
def _parse_chat_message_content_parts(
@ -336,10 +362,10 @@ def _parse_chat_message_content_parts(
for part in parts:
part_type = part["type"]
if part_type == "text":
text = _TextParser.validate_python(part)["text"]
text = _TextParser(part)["text"]
texts.append(text)
elif part_type == "image_url":
image_url = _ImageParser.validate_python(part)["image_url"]
image_url = _ImageParser(part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
@ -348,7 +374,7 @@ def _parse_chat_message_content_parts(
mm_parser.parse_image(image_url["url"])
elif part_type == "audio_url":
audio_url = _AudioParser.validate_python(part)["audio_url"]
audio_url = _AudioParser(part)["audio_url"]
mm_parser.parse_audio(audio_url["url"])
else:
@ -363,6 +389,11 @@ def _parse_chat_message_content_parts(
return [ConversationMessage(role=role, content=text_prompt)]
# No need to validate using Pydantic again
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
@ -371,16 +402,34 @@ def _parse_chat_message_content(
content = message.get("content")
if content is None:
return []
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)]
content = []
elif isinstance(content, str):
content = [
ChatCompletionContentPartTextParam(type="text", text=content)
]
return _parse_chat_message_content_parts(
result = _parse_chat_message_content_parts(
role,
content, # type: ignore
mm_tracker,
)
for result_msg in result:
if role == 'assistant':
parsed_msg = _AssistantParser(message)
if "tool_calls" in parsed_msg:
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool":
parsed_msg = _ToolParser(message)
if "tool_call_id" in parsed_msg:
result_msg["tool_call_id"] = parsed_msg["tool_call_id"]
if "name" in message and isinstance(message["name"], str):
result_msg["name"] = message["name"]
return result
def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
@ -428,6 +477,20 @@ def apply_chat_template(
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in conversation:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):
for i in range(len(message["tool_calls"])):
args: str = message["tool_calls"][i]["function"]["arguments"]
parsed_args: Dict = json.loads(args)
message["tool_calls"][i]["function"]["arguments"] = parsed_args
prompt = tokenizer.apply_chat_template(
conversation=conversation,
chat_template=chat_template,

View File

@ -233,7 +233,7 @@ def mount_metrics(app: FastAPI):
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
@ -283,11 +283,14 @@ async def show_version():
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
@ -422,7 +425,8 @@ async def init_app(
request_logger=request_logger,
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser)
openai_serving_completion = OpenAIServingCompletion(
async_engine_client,
model_config,

View File

@ -163,6 +163,24 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine.")
parser.add_argument(
"--enable-auto-tool-choice",
action="store_true",
default=False,
help=
"Enable auto tool choice for supported models. Use --tool-call-parser "
"to specify which parser to use")
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["mistral", "hermes"],
default=None,
help=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',

View File

@ -5,8 +5,9 @@ from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
import torch
from openai.types.chat import ChatCompletionContentPartParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
@ -35,6 +36,26 @@ assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
tool_call_id: Optional[str]
tool_calls: Optional[List[dict]]
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
@ -145,8 +166,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
ChatCompletionNamedToolChoiceParam]] = "none"
# NOTE this will be ignored by VLLM -- the model determines the behavior
parallel_tool_calls: Optional[bool] = False
user: Optional[str] = None
# doc: begin-chat-completion-sampling-params
@ -328,6 +352,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
if isinstance(data, ValueError):
raise data
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
@ -339,21 +366,61 @@ class ChatCompletionRequest(OpenAIBaseModel):
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and "tool_choice" in data and data[
"tool_choice"] != "none":
if guide_count > 1 and data.get("tool_choice",
"none") not in ("none", "auto"):
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data
@model_validator(mode="before")
@classmethod
def check_tool_choice(cls, data):
if "tool_choice" in data and data["tool_choice"] != "none":
if not isinstance(data["tool_choice"], dict):
raise ValueError("Currently only named tools are supported.")
def check_tool_usage(cls, data):
# if "tool_choice" is not specified but tools are provided,
# default to "auto" tool_choice
if "tool_choice" not in data and "tools" in data:
data["tool_choice"] = "auto"
# if "tool_choice" is specified -- validation
if "tool_choice" in data:
# ensure that if "tool choice" is specified, tools are present
if "tools" not in data or data["tools"] is None:
raise ValueError(
"When using `tool_choice`, `tools` must be set.")
# make sure that tool choice is either a named tool
# OR that it's set to "auto"
if data["tool_choice"] != "auto" and not isinstance(
data["tool_choice"], dict):
raise ValueError(
    "`tool_choice` must either be a named tool or \"auto\". "
    "`tool_choice=\"none\"` is not supported.")
# ensure that if "tool_choice" is specified as an object,
# it matches a valid tool
if isinstance(data["tool_choice"], dict):
valid_tool = False
specified_function = data["tool_choice"]["function"]
if not specified_function:
raise ValueError(
"Incorrectly formatted `tool_choice`. Should be like "
"`{\"type\": \"function\","
" \"function\": {\"name\": \"my_function\"}}`")
specified_function_name = specified_function["name"]
if not specified_function_name:
raise ValueError(
"Incorrectly formatted `tool_choice`. Should be like "
"`{\"type\": \"function\", "
"\"function\": {\"name\": \"my_function\"}}`")
for tool in data["tools"]:
if tool["function"]["name"] == specified_function_name:
valid_tool = True
break
if not valid_tool:
raise ValueError(
"The tool specified in `tool_choice` does not match any"
" of the specified `tools`")
return data
@ -413,7 +480,7 @@ class CompletionRequest(OpenAIBaseModel):
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
description="If specified, the output will follow the JSON schema.",
)
guided_regex: Optional[str] = Field(
default=None,
@ -633,9 +700,41 @@ class ToolCall(OpenAIBaseModel):
function: FunctionCall
class DeltaFunctionCall(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
index: int
function: Optional[DeltaFunctionCall] = None
# the initial delta that gets sent once a new tool call is started;
class InitialDeltaToolCall(DeltaToolCall):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
index: int
class ExtractedToolCallInformation(BaseModel):
# indicate if tools were called
tools_called: bool
# extracted tool calls
tool_calls: List[ToolCall]
# content - per the OpenAI spec, content AND tool calls are only rarely returned together,
# but some models will do this intentionally
content: Optional[str] = None
class ChatMessage(OpenAIBaseModel):
role: str
content: str
content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
@ -657,7 +756,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: Optional[ChatCompletionLogProbs] = None
# per OpenAI spec this is the default
finish_reason: Optional[str] = "stop"
# not part of the OpenAI spec but included in vLLM for legacy reasons
stop_reason: Optional[Union[int, str]] = None
@ -674,7 +775,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: List[DeltaToolCall] = Field(default_factory=list)
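# Illustrative sketch of how streamed tool-call deltas are expected to be
# reassembled by a client (chunk contents are hypothetical): the first chunk
# for a tool carries the id/type/index (InitialDeltaToolCall), the next one
# carries the function name, and each later chunk carries a fragment of the
# arguments string; fragments are concatenated in order until a chunk with
# finish_reason="tool_calls" arrives, e.g.
#   DeltaToolCall(index=0, function=DeltaFunctionCall(name="get_weather"))
#   DeltaToolCall(index=0, function=DeltaFunctionCall(arguments='{"city"'))
#   DeltaToolCall(index=0, function=DeltaFunctionCall(arguments=': "Dallas"}'))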
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):

View File

@ -1,6 +1,8 @@
import asyncio
import json
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
                    Optional)
from typing import Sequence as GenericSequence
from typing import Union
@ -18,15 +20,18 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
PromptAdapterPath,
TextTokensPrompt)
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
MistralToolParser,
ToolParser)
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
@ -38,19 +43,19 @@ logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing):
def __init__(self,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None):
super().__init__(async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
@ -60,10 +65,27 @@ class OpenAIServingChat(OpenAIServing):
return_tokens_as_token_ids=return_tokens_as_token_ids)
self.response_role = response_role
# If this is None we use the tokenizer's default chat template
self.use_tool_use_model_template = False
self.chat_template = load_chat_template(chat_template)
# set up tool use
self.enable_auto_tools: bool = enable_auto_tools
if self.enable_auto_tools:
logger.info(
    "\"auto\" tool choice has been enabled; please note that while "
    "the parallel_tool_calls client option is present for "
    "compatibility reasons, it will be ignored.")
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools:
if tool_parser == "mistral":
self.tool_parser = MistralToolParser
elif tool_parser == "hermes":
self.tool_parser = Hermes2ProToolParser
else:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
async def create_chat_completion(
self,
request: ChatCompletionRequest,
@ -76,11 +98,10 @@ class OpenAIServingChat(OpenAIServing):
for the API specification. This API mimics the OpenAI
ChatCompletion API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret)
return error_check_ret
try:
@ -119,6 +140,20 @@ class OpenAIServingChat(OpenAIServing):
logger.error("Error in loading multi-modal data: %s", e)
return self.create_error_response(str(e))
# validation for OpenAI tools
# tool_choice = "required" is not supported
if request.tool_choice == "required":
return self.create_error_response(
"tool_choice = \"required\" is not supported!")
# "auto" tools requires --enable-auto-tool-choice
# and --tool-call-parser
if request.tool_choice == "auto" and not (
self.enable_auto_tools and self.tool_parser is not None):
return self.create_error_response(
"\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set")
request_id = f"chat-{random_uuid()}"
try:
guided_decode_logits_processor = (
@ -187,6 +222,7 @@ class OpenAIServingChat(OpenAIServing):
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer)
try:
return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer)
@ -219,6 +255,9 @@ class OpenAIServingChat(OpenAIServing):
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None
try:
async for res in result_generator:
# We need to do it here, because if there are exceptions in
@ -228,6 +267,9 @@ class OpenAIServingChat(OpenAIServing):
# Send first response for each request.n (index) with
# the role
role = self.get_chat_request_role(request)
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for i in range(num_choices):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
@ -240,14 +282,18 @@ class OpenAIServingChat(OpenAIServing):
created=created_time,
choices=[choice_data],
model=model_name)
# if usage should be included
if (request.stream_options
and request.stream_options.include_usage):
# if continuous usage stats are requested, add it
if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
chunk.usage = usage
# otherwise don't
else:
chunk.usage = None
@ -257,7 +303,7 @@ class OpenAIServingChat(OpenAIServing):
# Send response to echo the input portion of the
# last message
if request.echo:
last_msg_content = ""
last_msg_content: Optional[str] = ""
if conversation and conversation[-1].get(
"content") and conversation[-1].get(
"role") == role:
@ -298,6 +344,7 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = False
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
@ -320,19 +367,49 @@ class OpenAIServingChat(OpenAIServing):
logprobs = None
delta_text = output.text[len(previous_texts[i]):]
delta_message: Optional[DeltaMessage] = None
# handle streaming deltas for tools with named tool_choice
if (request.tool_choice and type(request.tool_choice) is
ChatCompletionNamedToolChoiceParam):
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text),
index=i)
])
# handle streaming deltas for tools with "auto" tool choice
elif (self._should_stream_with_auto_tool_parsing(request)
and tool_parser):
delta_message = (
tool_parser.extract_tool_calls_streaming(
previous_text=previous_texts[i],
current_text=output.text,
delta_text=delta_text,
previous_token_ids=output.token_ids[:-1 * len(delta_token_ids)],
current_token_ids=output.token_ids,
delta_token_ids=delta_token_ids))
# handle streaming just a content delta
else:
delta_message = DeltaMessage(content=delta_text)
# set the previous values for the next iteration
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
# if the message delta is None (e.g. because it was a
# "control token" for tool calls, or the parser otherwise
# wasn't ready to send a token), then
# get the next token without streaming a chunk
if delta_message is None:
continue
if output.finish_reason is None:
# Send token-by-token response for each request.n
@ -348,6 +425,8 @@ class OpenAIServingChat(OpenAIServing):
created=created_time,
choices=[choice_data],
model=model_name)
# handle usage stats if requested & if continuous
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
@ -365,14 +444,55 @@ class OpenAIServingChat(OpenAIServing):
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# if the model is finished generating
else:
# check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously
# matched by partial json parsing
# only happens if we are NOT using guided decoding
if tool_parser:
index = len(
tool_parser.prev_tool_call_arr) - 1 if len(
tool_parser.prev_tool_call_arr) > 0 else 0
else:
index = 0
if self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output) and tool_parser:
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON
expected_call = json.dumps(
tool_parser.prev_tool_call_arr[index].get(
"arguments", {}))
# get what we've streamed so far for arguments
# for the current tool
actual_call = tool_parser.streamed_args_for_tool[
index]
# check to see if there's anything left to stream
remaining_call = expected_call.replace(
actual_call, "", 1)
# set that as a delta message
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(index=index,
function=DeltaFunctionCall(
arguments=remaining_call).
model_dump(exclude_none=True))
])
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=output.finish_reason
if not (tool_parser
and len(tool_parser.prev_tool_call_arr))
else "tool_calls",
stop_reason=output.stop_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
@ -398,6 +518,8 @@ class OpenAIServingChat(OpenAIServing):
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
# once the final token is handled, if stream_options.include_usage
# is sent, send the usage
if (request.stream_options
and request.stream_options.include_usage):
final_usage = UsageInfo(
@ -419,6 +541,7 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
logger.error("error in chat completion stream generator: %s", e)
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
@ -463,8 +586,21 @@ class OpenAIServingChat(OpenAIServing):
else:
logprobs = None
# by default, tools are not used.
tools_called = False
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
if not (self.enable_auto_tools
or not self.tool_parser) and not isinstance(
request.tool_choice,
ChatCompletionNamedToolChoiceParam):
message = ChatMessage(role=role, content=output.text)
# if the request uses tools and specified a tool choice
elif request.tool_choice and type(
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
message = ChatMessage(
role=role,
content="",
@ -473,14 +609,47 @@ class OpenAIServingChat(OpenAIServing):
name=request.tool_choice.function.name,
arguments=output.text))
])
tools_called = True
# if the request doesn't use tool choice
# OR specifies to not use a tool
elif not request.tool_choice or request.tool_choice == "none":
message = ChatMessage(role=role, content=output.text)
# handle when there are tools and tool choice is auto
elif request.tools and (
request.tool_choice == "auto"
or request.tool_choice is None) and self.enable_auto_tools \
and self.tool_parser:
tool_parser = self.tool_parser(tokenizer)
tool_call_info = tool_parser.extract_tool_calls(output.text)
tools_called = tool_call_info.tools_called
if tool_call_info.tools_called:
message = ChatMessage(role=role,
content=tool_call_info.content,
tool_calls=tool_call_info.tool_calls)
else:
# FOR NOW make it a chat message; we will have to detect
# the type to make it later.
message = ChatMessage(role=role, content=output.text)
# undetermined case that is still important to handle
else:
logger.error(
"Error in chat_completion_full_generator - cannot determine"
" if tools should be extracted. Returning a standard chat "
"completion.")
message = ChatMessage(role=role, content=output.text)
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=message,
logprobs=logprobs,
finish_reason="tool_calls" if tools_called else
output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason)
choices.append(choice_data)
@ -488,10 +657,11 @@ class OpenAIServingChat(OpenAIServing):
last_msg_content = ""
if conversation and conversation[-1].get(
"content") and conversation[-1].get("role") == role:
last_msg_content = conversation[-1]["content"]
last_msg_content = conversation[-1]["content"] or ""
for choice in choices:
full_message = last_msg_content + (choice.message.content
or "")
choice.message.content = full_message
num_prompt_tokens = len(final_res.prompt_token_ids)
@ -574,3 +744,38 @@ class OpenAIServingChat(OpenAIServing):
))
return ChatCompletionLogProbs(content=logprobs_content)
def _should_stream_with_auto_tool_parsing(self,
request: ChatCompletionRequest):
"""
Utility function to check if streamed tokens should go through the tool
call parser that was configured.
We only want to do this IF user-provided tools are set, a tool parser
is configured, "auto" tool choice is enabled, and the request's tool
choice field indicates that "auto" tool choice should be used.
"""
return (request.tools and self.tool_parser and self.enable_auto_tools
and request.tool_choice in ['auto', None])
def _should_check_for_unstreamed_tool_arg_tokens(
self,
delta_message: Optional[DeltaMessage],
output: CompletionOutput,
) -> bool:
"""
Check to see if we should check for unstreamed tool arguments tokens.
This is only applicable when auto tool parsing is enabled, the delta
is a tool call with arguments, and generation has finished.
"""
# yapf: disable
return bool(
# if there is a delta message that includes tool calls which
# include a function that has arguments
self.enable_auto_tools and self.tool_parser and delta_message
and delta_message.tool_calls and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
and output.finish_reason is not None
)
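# Illustrative example of the "unstreamed argument tokens" case handled in
# the stream generator above (values are hypothetical): if partial-JSON
# parsing autocompleted the arguments to '{"city": "Dallas"}' but only
# '{"city": "Dal' has been streamed so far, the remaining fragment
# 'las"}' is sent as one final DeltaFunctionCall before the finish chunk.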

View File

@ -43,7 +43,11 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger=request_logger)
# If this is None we use the tokenizer's default chat template
# the list of commonly-used chat template names for HF named templates
hf_chat_templates: List[str] = ['default', 'tool_use']
self.chat_template = chat_template \
if chat_template in hf_chat_templates \
else load_chat_template(chat_template)
async def create_tokenize(
self,

View File

@ -0,0 +1,5 @@
from .abstract_tool_parser import ToolParser
from .hermes_tool_parser import Hermes2ProToolParser
from .mistral_tool_parser import MistralToolParser
__all__ = ["ToolParser", "Hermes2ProToolParser", "MistralToolParser"]
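# Usage sketch: these parsers are selected at server start-up via
# --enable-auto-tool-choice together with --tool-call-parser hermes|mistral,
# and serving_chat.py instantiates the chosen class per request with the
# model tokenizer, e.g. parser = Hermes2ProToolParser(tokenizer).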

View File

@ -0,0 +1,58 @@
from typing import Dict, List, Sequence, Union
from vllm.entrypoints.openai.protocol import (DeltaMessage,
ExtractedToolCallInformation)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
class ToolParser:
"""
Abstract ToolParser class that should not be used directly. Provided
properties and methods should be used in
derived classes.
"""
def __init__(self, tokenizer: AnyTokenizer):
self.prev_tool_call_arr: List[Dict] = []
# the index of the tool call that is currently being parsed
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.current_tool_initial_sent: bool = False
self.streamed_args_for_tool: List[str] = []
self.model_tokenizer = tokenizer
def extract_tool_calls(self,
model_output: str) -> ExtractedToolCallInformation:
"""
Method that should be implemented for extracting tool calls from
a complete model-generated string.
Used for non-streaming responses where we have the entire model response
available before sending to the client; unlike the streaming variant,
it does not need any parsing state.
"""
raise NotImplementedError(
"AbstractToolParser.extract_tool_calls has not been implemented!")
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
"""
Instance method that should be implemented for extracting tool calls
from an incomplete response; for use when handling tool calls and
streaming. Has to be an instance method because it requires state -
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise NotImplementedError(
"AbstractToolParser.extract_tool_calls_streaming has not been "
"implemented!")

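# Minimal sketch of a hypothetical subclass (illustration only), showing the
# non-streaming contract of returning ExtractedToolCallInformation:
#
#   class NoOpToolParser(ToolParser):
#       def extract_tool_calls(self, model_output: str
#                              ) -> ExtractedToolCallInformation:
#           # never detects tool calls; passes the text through as content
#           return ExtractedToolCallInformation(tools_called=False,
#                                               tool_calls=[],
#                                               content=model_output)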
View File

@ -0,0 +1,344 @@
import json
import re
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
InitialDeltaToolCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
logger = init_logger(__name__)
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
logger.error(
"Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = self.model_tokenizer.tokenizer
self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.current_tool_initial_sent: bool = False
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_regex = re.compile(
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
self.scratch_pad_regex = re.compile(
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_call_start_token_id: int = self.model_tokenizer.vocab[
self.tool_call_start_token]
self.tool_call_end_token_id: int = self.model_tokenizer.vocab[
self.tool_call_end_token]
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
def extract_tool_calls(self,
model_output: str) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_call_start_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
else:
try:
# there are two possible captures - between tags, or between a
# tag and end-of-string so the result of
# findall is an array of tuples where one is a function call and
# the other is None
function_call_tuples = (
self.tool_call_regex.findall(model_output))
# load the JSON, and then use it to build the Function and
# Tool Call
raw_function_calls = [
json.loads(match[0] if match[0] else match[1])
for match in function_call_tuples
]
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"])))
for function_call in raw_function_calls
]
content = model_output[:model_output.
find(self.tool_call_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None)
except Exception as e:
logger.error("Error in extracting tool call from response %s",
e)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
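# Illustrative example (hypothetical model output): for
#   'Sure.<tool_call>{"name": "get_weather", "arguments": {"city": "Dallas"}}</tool_call>'
# the method above returns tools_called=True, content="Sure." and a single
# ToolCall whose arguments are re-serialized to the JSON string
# '{"city": "Dallas"}'.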
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
# check to see if we should be streaming a tool call - is there a
# tool call start token in the tokens generated so far?
if self.tool_call_start_token_id not in current_token_ids:
logger.debug("No tool call tokens found!")
return DeltaMessage(content=delta_text)
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count = previous_token_ids.count(
self.tool_call_start_token_id)
prev_tool_end_count = previous_token_ids.count(
self.tool_call_end_token_id)
cur_tool_start_count = current_token_ids.count(
self.tool_call_start_token_id)
cur_tool_end_count = current_token_ids.count(
self.tool_call_end_token_id)
# case: if we're generating text, OR rounding out a tool call
if (cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count):
logger.debug("Generating text content! skipping tool parsing.")
if delta_text != self.tool_call_end_token:
return DeltaMessage(content=delta_text)
# case: if tool open & close tag counts don't match, we're doing
# imaginary "else" block here
# something with tools with this diff.
# flags for partial JSON parting. exported constants from
# "Allow" are handled via BIT MASK
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
# case -- we're starting a new tool call
if (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count):
if len(delta_token_ids) > 1:
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
else:
tool_call_portion = None
delta = None
text_portion = None
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.current_tool_initial_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)
# case -- we're updating an existing tool call
elif (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count):
# get the portion of the text that's the tool call
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
text_portion = None
# case -- the current tool call is being closed.
elif (cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count > prev_tool_end_count):
diff = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments")
if diff:
diff = json.dumps(diff).replace(
self.streamed_args_for_tool[self.current_tool_id], "")
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s", diff)
self.streamed_args_for_tool[self.current_tool_id] \
+= diff
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(
exclude_none=True))
])
# case -- otherwise we're just generating text
else:
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
return delta
try:
current_tool_call = partial_json_parser.loads(
tool_call_portion or "{}",
flags) if tool_call_portion else None
logger.debug("Parsed tool call %s", current_tool_call)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# case - we haven't sent the initial delta with the tool call ID
# (it will be sent)
if not self.current_tool_initial_sent:
self.current_tool_initial_sent = True
return DeltaMessage(tool_calls=[
InitialDeltaToolCall(
index=self.current_tool_id).model_dump(
exclude_none=True)
])
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
elif not self.current_tool_name_sent:
function_name: Union[str, None] = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
else:
return None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if tool_call_portion is None:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta = DeltaMessage(content=delta_text) \
if text_portion is not None else None
return delta
# now, the nitty-gritty of tool calls: we have the portion to parse as a
# tool call.
logger.debug("Trying to parse current tool call with ID %s",
self.current_tool_id)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments = (
self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
logger.debug("against new ones: %s", cur_arguments)
# case -- no arguments have been created yet. skip sending a delta.
if not cur_arguments and not prev_arguments:
logger.debug("Skipping text %s - no arguments", delta_text)
delta = None
# case -- prev arguments are defined, but none are now.
# probably impossible, but not a fatal error - just keep going
elif not cur_arguments and prev_arguments:
logger.error("should be impossible to have arguments reset "
"mid-call. skipping streaming anything.")
delta = None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments)
logger.debug("finding %s in %s", delta_text,
cur_arguments_json)
# get the location where previous args differ from current
args_delta_start_loc = cur_arguments_json.index(delta_text) \
+ len(delta_text)
# use that to find the actual delta
arguments_delta = cur_arguments_json[:args_delta_start_loc]
logger.debug("First tokens in arguments received: %s",
arguments_delta)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= arguments_delta
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
logger.debug("Searching for diff between\n%s", cur_args_json)
logger.debug("and\n%s", prev_args_json)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
logger.debug("got argument diff %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= argument_diff
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = \
current_tool_call
else:
self.prev_tool_call_arr.append(current_tool_call)
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
return None # do not stream a delta. skip this token ID.

View File

@ -0,0 +1,293 @@
import json
import re
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
InitialDeltaToolCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
logger = init_logger(__name__)
class MistralToolParser(ToolParser):
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
self.model_tokenizer = self.model_tokenizer.tokenizer
else:
logger.info("Non-Mistral tokenizer detected when using a Mistral "
"model...")
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.current_tool_initial_sent: bool = False
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.model_tokenizer.vocab[self.bot_token]
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
def extract_tool_calls(self,
model_output: str) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing;
make sure your tool call arguments don't ever include quotes!
"""
# case -- if a tool call token is not present, return a text response
if self.bot_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
# use a regex to find the bracketed tool call array after removing
# the BOT token
raw_tool_call = self.tool_call_regex.findall(
model_output.replace(self.bot_token, ""))[0]
# load the JSON, and then use it to build the Function and
# Tool Call
function_call_arr = json.loads(raw_tool_call)
tool_calls: List[ToolCall] = [
ToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(raw_function_call["arguments"])))
for raw_function_call in function_call_arr
]
# get any content before the tool call
content = model_output.split(self.bot_token)[0]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if len(content) > 0 else None)
except Exception as e:
logger.error("Error in extracting tool call from response: %s", e)
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
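# Illustrative example (hypothetical model output): for
#   '[TOOL_CALLS] [{"name": "get_weather", "arguments": {"city": "Dallas"}}]'
# the regex captures the bracketed array, json.loads parses it, and the
# method above returns a single ToolCall with arguments serialized as
# '{"city": "Dallas"}' and no content.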
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
if self.bot_token_id not in current_token_ids:
return DeltaMessage(content=delta_text)
# if the tool call token ID IS in the tokens generated so far, that
# means we're parsing as tool calls now
# handle if we detected the BOT token which means the start of tool
# calling
if (self.bot_token_id in delta_token_ids
and len(delta_token_ids) == 1):
# if it's the only token, return None, so we don't send a chat
# completion chunk and don't send a control token
return None
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
parsable_arr = current_text.split(self.bot_token)[1]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try:
tool_call_arr: List[Dict] = partial_json_parser.loads(
parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# select the tool call that our parsing state says we are currently on
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (len(tool_call_arr) > 0
and len(tool_call_arr) > self.current_tool_id + 1):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
diff: Union[str, None] = current_tool_call.get("arguments")
if diff:
diff = json.dumps(diff).replace(
self.streamed_args_for_tool[self.current_tool_id],
"")
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.current_tool_initial_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# case: update an existing tool - this is handled below
# if the current tool initial data incl. the id, type=function
# and idx not sent, send that
if not self.current_tool_initial_sent:
self.current_tool_initial_sent = True
delta = DeltaMessage(tool_calls=[
InitialDeltaToolCall(
index=self.current_tool_id).model_dump(
exclude_none=True)
])
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
cur_arguments = current_tool_call.get("arguments")
new_text = delta_text.replace("\'", "\"")
if not cur_arguments and not prev_arguments:
delta = None
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset "
"mid-arguments")
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments)
logger.debug("finding %s in %s", new_text,
cur_arguments_json)
arguments_delta = cur_arguments_json[:cur_arguments_json.
index(new_text) +
len(new_text)]
logger.debug("First tokens in arguments received: %s",
arguments_delta)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
logger.debug("Searching for diff between \n%s\n%s",
cur_args_json, prev_args_json)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
else:
# try parsing it with regular JSON - if it works we're
# at the end, and we need to send the difference between
# tokens streamed so far and the valid JSON
delta = None
# finish by saving the current tool call state for the next diff,
# then return whatever delta was produced above (possibly None)
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None

View File

@ -0,0 +1,87 @@
def find_common_prefix(s1: str, s2: str) -> str:
"""
Finds a common prefix that is shared between two strings, if there is one.
Order of arguments is NOT important.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely.
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
'{"fruit": "ap'
"""
prefix = ''
min_length = min(len(s1), len(s2))
for i in range(0, min_length):
if s1[i] == s2[i]:
prefix += s1[i]
else:
break
return prefix
def find_common_suffix(s1: str, s2: str) -> str:
"""
Finds a common suffix shared between two strings, if there is one. Order of
arguments is NOT important.
Stops when the suffix ends OR it hits an alphanumeric character
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
"""
suffix = ''
min_length = min(len(s1), len(s2))
for i in range(1, min_length + 1):
if s1[-i] == s2[-i] and not s1[-i].isalnum():
suffix = s1[-i] + suffix
else:
break
return suffix
def extract_intermediate_diff(curr: str, old: str) -> str:
"""
Given two strings, extract the difference in the middle between two strings
that are known to have a common prefix and/or suffix.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely. The order of arguments IS
important - the new version of the partially-parsed JSON must be the first
argument, and the second argument must be from the previous generation.
What it returns, is tokens that should be streamed to the client.
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
-> 'ple'
"""
suffix = find_common_suffix(curr, old)
old = old[::-1].replace(suffix[::-1], '', 1)[::-1]
prefix = find_common_prefix(curr, old)
diff = curr
if len(suffix):
diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1]
if len(prefix):
# replace the prefix only once in case it's mirrored
diff = diff.replace(prefix, '', 1)
return diff
def find_all_indices(string, substring):
"""
Find all (starting) indices of a substring in a given string. Useful for
tool call extraction
"""
indices = []
index = -1
while True:
index = string.find(substring, index + 1)
if index == -1:
break
indices.append(index)
return indices
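# Illustrative usage of the helpers above (values chosen for demonstration):
#
#   find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}')
#       -> '{"fruit": "ap'
#   find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}')
#       -> '"}'
#   extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
#       -> 'ple'
#   find_all_indices('abcabc', 'abc')
#       -> [0, 3]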

View File

@ -59,8 +59,9 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest,
if type(request) is CompletionRequest:
return request
# user has chosen to not use any tool,
# OR is allowing the model to choose a tool.
if request.tool_choice == "none" or request.tool_choice == "auto":
return request
# user has chosen to use a named tool

View File

@ -8,8 +8,9 @@ from typing import Tuple, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
@ -101,16 +102,30 @@ def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest,
GuidedDecodingRequest]
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
# if the request is a chat completion request, AND the tool choice is a
# named tool choice, do guided decoding
# using that tool as the JSON schema
if isinstance(request, ChatCompletionRequest) and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam):
# Guided generation for tools/functions parameters
if request.tool_choice.type == "function":
for tool in request.tools:
if (tool.type == "function" and tool.function.name
== request.tool_choice.function.name):
json = json_dumps(tool.function.parameters, sort_keys=True)
return json, GuidedDecodingMode.JSON
return None, None
elif request.guided_json:
    if isinstance(request.guided_json, dict):
        # turn dict into hashable string
        json = json_dumps(request.guided_json)
    elif isinstance(request.guided_json, BaseModel):
        # use pydantic signature so that different model classes
        # with the same fields will get hashed the same
        json = str(request.guided_json.__signature__)
    else:
        json = request.guided_json
return json, GuidedDecodingMode.JSON
elif request.guided_regex:
return request.guided_regex, GuidedDecodingMode.REGEX
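# Illustrative sketch of the named-tool path above (tool definition is
# hypothetical): given
#   tools=[{"type": "function", "function": {
#       "name": "get_weather",
#       "parameters": {"type": "object",
#                      "properties": {"city": {"type": "string"}},
#                      "required": ["city"]}}}]
#   tool_choice={"type": "function", "function": {"name": "get_weather"}}
# _get_guide_and_mode returns (json_dumps(parameters, sort_keys=True),
# GuidedDecodingMode.JSON), so decoding is constrained to a valid arguments
# object such as {"city": "Dallas"}.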