[Feature] OpenAI-Compatible Tools API + Streaming for Hermes & Mistral models (#5649)
Co-authored-by: constellate <constellate@1-ai-appserver-staging.codereach.com>
Co-authored-by: Kyle Mistele <kyle@constellate.ai>
This commit is contained in:
parent
561d6f8077
commit
e02ce498be
@ -92,6 +92,7 @@ steps:
- pytest -v -s entrypoints/openai
- pytest -v -s entrypoints/test_chat_utils.py

- label: Distributed Tests (4 GPUs) # 10min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
@ -271,6 +272,15 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- bash ./run-tests.sh -c configs/models-small.txt -t 1

- label: OpenAI-Compatible Tool Use # 20 min
fast_check: false
mirror_hardwares: [ amd ]
source_file_dependencies:
- vllm/
- tests/tool_use
commands:
- pytest -v -s tool_use

##### 1 GPU test #####
##### multi gpus test #####
@ -110,6 +110,14 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
|
||||
:func: create_parser_for_docs
|
||||
:prog: vllm serve
|
||||
```
|
||||
## Tool Calling in the Chat Completion API

### Named Function Calling

vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is
enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a
high-quality one.

To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.
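As a rough sketch (the tool schema and prompt below are made up purely for illustration), a named function call with the OpenAI Python client against a running vLLM server might look like this:

```python
from openai import OpenAI

# vLLM's OpenAI-compatible server; the API key is unused here but required by the client.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
model = client.models.list().data[0].id

# Hypothetical tool schema, defined only for this example.
tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "state": {"type": "string"},
            },
            "required": ["city", "state"],
        },
    },
}]

completion = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "What is the weather in Dallas, TX?"}],
    tools=tools,
    # Named function calling: force a call to this specific tool. vLLM uses
    # guided decoding, so the arguments will conform to the JSON schema above.
    tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
)
print(completion.choices[0].message.tool_calls[0].function.arguments)
```

Because the tool is named explicitly in `tool_choice`, this works with any supported model and does not require the auto tool choice flags described below.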

### Config file

@ -140,10 +148,52 @@ The order of priorities is `command line > config file values > defaults`.

## Tool calling in the chat completion API

vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap.

To use a named function you need to define the function in the `tools` parameter and call it in the `tool_choice` parameter.

It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt. **This may change in the future.**
It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt.

vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.

Please refer to the OpenAI API reference documentation for more information.

### Automatic Function Calling

To enable this feature, you should set the following flags (a complete example invocation is shown after the list):
* `--enable-auto-tool-choice` -- **mandatory**. Tells vLLM that you want to enable the model to generate its own tool calls when it
deems appropriate.
* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers
will continue to be added in the future.
* `--chat-template` -- **optional** for auto tool choice. The path to the chat template which handles `tool`-role messages and `assistant`-role messages
that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates)
from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json)
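
For example, the Hermes server launch shown in this PR's `examples/openai_chat_completion_client_with_tools.py` combines these flags:

```
vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \
    --chat-template examples/tool_chat_template_hermes.jinja \
    --enable-auto-tool-choice --tool-call-parser hermes
```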

If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template!

#### Hermes Models
All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported.
* `NousResearch/Hermes-2-Pro-*`
* `NousResearch/Hermes-2-Theta-*`
* `NousResearch/Hermes-3-*`

_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge
step in their creation_.

Flags: `--tool-call-parser hermes`

#### Mistral Models
Supported models:
* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed)
* Additional Mistral function-calling models are compatible as well.

Known issues:
1. Mistral 7B struggles to generate parallel tool calls correctly.
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided:

* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that
it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits)
* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt
when tools are provided, which results in much better reliability when working with parallel tool calling.

Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`

|
162
examples/openai_chat_completion_client_with_tools.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""
|
||||
Set up this example by starting a vLLM OpenAI-compatible server with tool call
|
||||
options enabled. For example:
|
||||
|
||||
IMPORTANT: for mistral, you must use one of the provided mistral tool call
|
||||
templates, or your own - the model default doesn't work for tool calls with vLLM
|
||||
See the vLLM docs on OpenAI server & tool calling for more details.
|
||||
|
||||
vllm serve --model mistralai/Mistral-7B-Instruct-v0.3 \
|
||||
--chat-template examples/tool_chat_template_mistral.jinja \
|
||||
--enable-auto-tool-choice --tool-call-parser mistral
|
||||
|
||||
OR
|
||||
vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \
|
||||
--chat-template examples/tool_chat_template_hermes.jinja \
|
||||
--enable-auto-tool-choice --tool-call-parser hermes
|
||||
"""
|
||||
import json
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to find the weather for, e.g. 'San Francisco'"
|
||||
},
|
||||
"state": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"the two-letter abbreviation for the state that the city is"
|
||||
" in, e.g. 'CA' which would mean 'California'"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
},
|
||||
"required": ["city", "state", "unit"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": "Hi! How are you doing today?"
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well! How can I help you?"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
|
||||
}]
|
||||
|
||||
chat_completion = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools)
|
||||
|
||||
print("Chat completion results:")
|
||||
print(chat_completion)
|
||||
print("\n\n")
|
||||
|
||||
tool_calls_stream = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
stream=True)
|
||||
|
||||
chunks = []
|
||||
for chunk in tool_calls_stream:
|
||||
chunks.append(chunk)
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
print(chunk.choices[0].delta.tool_calls[0])
|
||||
else:
|
||||
print(chunk.choices[0].delta)
|
||||
|
||||
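# Reassemble the streamed tool call: for each tool call index in the deltas,
# concatenate its argument fragments into a complete JSON string.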
arguments = []
|
||||
tool_call_idx = -1
|
||||
for chunk in chunks:
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
|
||||
if tool_call.index != tool_call_idx:
|
||||
if tool_call_idx >= 0:
|
||||
print(
|
||||
f"streamed tool call arguments: {arguments[tool_call_idx]}"
|
||||
)
|
||||
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
|
||||
arguments.append("")
|
||||
if tool_call.id:
|
||||
print(f"streamed tool call id: {tool_call.id} ")
|
||||
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
print(f"streamed tool call name: {tool_call.function.name}")
|
||||
|
||||
if tool_call.function.arguments:
|
||||
arguments[tool_call_idx] += tool_call.function.arguments
|
||||
|
||||
if len(arguments):
|
||||
print(f"streamed tool call arguments: {arguments[-1]}")
|
||||
|
||||
print("\n\n")
|
||||
|
||||
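# Append the assistant's tool calls to the conversation so that the tool
# results added below can reference them.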
messages.append({
|
||||
"role": "assistant",
|
||||
"tool_calls": chat_completion.choices[0].message.tool_calls
|
||||
})
|
||||
|
||||
|
||||
# Now, simulate a tool call
|
||||
def get_current_weather(city: str, state: str, unit: str):
|
||||
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
|
||||
"partly cloudly, with highs in the 90's.")
|
||||
|
||||
|
||||
available_tools = {"get_current_weather": get_current_weather}
|
||||
|
||||
completion_tool_calls = chat_completion.choices[0].message.tool_calls
|
||||
for call in completion_tool_calls:
|
||||
tool_to_call = available_tools[call.function.name]
|
||||
args = json.loads(call.function.arguments)
|
||||
result = tool_to_call(**args)
|
||||
print(result)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": result,
|
||||
"tool_call_id": call.id,
|
||||
"name": call.function.name
|
||||
})
|
||||
|
||||
chat_completion_2 = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
stream=False)
|
||||
print("\n\n")
|
||||
print(chat_completion_2)
|
129
examples/tool_chat_template_hermes.jinja
Normal file
@ -0,0 +1,129 @@
|
||||
{%- macro json_to_python_type(json_spec) %}
|
||||
{%- set basic_type_map = {
|
||||
"string": "str",
|
||||
"number": "float",
|
||||
"integer": "int",
|
||||
"boolean": "bool"
|
||||
} %}
|
||||
|
||||
{%- if basic_type_map[json_spec.type] is defined %}
|
||||
{{- basic_type_map[json_spec.type] }}
|
||||
{%- elif json_spec.type == "array" %}
|
||||
{{- "list[" + json_to_python_type(json_spec|items) + "]" }}
|
||||
{%- elif json_spec.type == "object" %}
|
||||
{%- if json_spec.additionalProperties is defined %}
|
||||
{{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }}
|
||||
{%- else %}
|
||||
{{- "dict" }}
|
||||
{%- endif %}
|
||||
{%- elif json_spec.type is iterable %}
|
||||
{{- "Union[" }}
|
||||
{%- for t in json_spec.type %}
|
||||
{{- json_to_python_type({"type": t}) }}
|
||||
{%- if not loop.last %}
|
||||
{{- "," }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "]" }}
|
||||
{%- else %}
|
||||
{{- "Any" }}
|
||||
{%- endif %}
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
{{- bos_token }}
|
||||
{{- "<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> " }}
|
||||
{%- if tools is iterable and tools | length > 0 %}
|
||||
{%- for tool in tools %}
|
||||
{%- if tool.function is defined %}
|
||||
{%- set tool = tool.function %}
|
||||
{%- endif %}
|
||||
{{- '{"type": "function", "function": ' }}
|
||||
{{- '{"name": "' + tool.name + '", ' }}
|
||||
{{- '"description": "' + tool.name + '(' }}
|
||||
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
||||
{{- param_name + ": " + json_to_python_type(param_fields) }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- ")" }}
|
||||
{%- if tool.return is defined %}
|
||||
{{- " -> " + json_to_python_type(tool.return) }}
|
||||
{%- endif %}
|
||||
{{- " - " + tool.description + "\n\n" }}
|
||||
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
||||
{%- if loop.first %}
|
||||
{{- " Args:\n" }}
|
||||
{%- endif %}
|
||||
{{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}
|
||||
{%- endfor %}
|
||||
{%- if tool.return is defined and tool.return.description is defined %}
|
||||
{{- "\n Returns:\n " + tool.return.description }}
|
||||
{%- endif %}
|
||||
{{- '"' }}
|
||||
{{- ', "parameters": ' }}
|
||||
{%- if tool.parameters.properties | length == 0 %}
|
||||
{{- "{}" }}
|
||||
{%- else %}
|
||||
{{- tool.parameters|tojson }}
|
||||
{%- endif %}
|
||||
{{- "}" }}
|
||||
{%- if not loop.last %}
|
||||
{{- "\n" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- " </tools>" }}
|
||||
{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}}
|
||||
' }}
|
||||
{{- "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
" }}
|
||||
{{- "<tool_call>
|
||||
" }}
|
||||
{{- '{"name": <function-name>, "arguments": <args-dict>}
|
||||
' }}
|
||||
{{- '</tool_call><|im_end|>' }}
|
||||
{%- for message in messages %}
|
||||
{%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" and message.tool_calls is defined %}
|
||||
{{- '<|im_start|>' + message.role }}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{{- '\n<tool_call>\n' }}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '{' }}
|
||||
{{- '"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '"}' }}
|
||||
{{- ', ' }}
|
||||
{%- if tool_call.arguments is defined %}
|
||||
{{- '"arguments": ' }}
|
||||
{{- tool_call.arguments|tojson }}
|
||||
{%- endif %}
|
||||
{{- '\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
||||
{{- '<|im_start|>tool\n' }}
|
||||
{%- endif %}
|
||||
{{- '<tool_response>\n' }}
|
||||
{{- message.content }}
|
||||
{%- if not loop.last %}
|
||||
{{- '\n</tool_response>\n' }}
|
||||
{%- else %}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- endif %}
|
||||
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
||||
{{- '<|im_end|>' }}
|
||||
{%- elif loop.last %}
|
||||
{{- '<|im_end|>' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
86
examples/tool_chat_template_mistral.jinja
Normal file
@ -0,0 +1,86 @@
|
||||
{%- if messages[0]["role"] == "system" %}
|
||||
{%- set system_message = messages[0]["content"] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
|
||||
|
||||
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
|
||||
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
|
||||
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- for message in loop_messages %}
|
||||
{%- if message["role"] == "user" %}
|
||||
{%- if tools is not none and (message == user_messages[-1]) %}
|
||||
{{- "[AVAILABLE_TOOLS] [" }}
|
||||
{%- for tool in tools %}
|
||||
{%- set tool = tool.function %}
|
||||
{{- '{"type": "function", "function": {' }}
|
||||
{%- for key, val in tool.items() if key != "return" %}
|
||||
{%- if val is string %}
|
||||
{{- '"' + key + '": "' + val + '"' }}
|
||||
{%- else %}
|
||||
{{- '"' + key + '": ' + val|tojson }}
|
||||
{%- endif %}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "}}" }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "[/AVAILABLE_TOOLS]" }}
|
||||
{%- endif %}
|
||||
{%- if loop.last and system_message is defined %}
|
||||
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
|
||||
{%- else %}
|
||||
{{- "[INST] " + message["content"] + "[/INST]" }}
|
||||
{%- endif %}
|
||||
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
|
||||
{%- if message.tool_calls is defined %}
|
||||
{%- set tool_calls = message.tool_calls %}
|
||||
{%- else %}
|
||||
{%- set tool_calls = message.content %}
|
||||
{%- endif %}
|
||||
{{- "[TOOL_CALLS] [" }}
|
||||
{%- for tool_call in tool_calls %}
|
||||
{%- set out = tool_call.function|tojson %}
|
||||
{{- out[:-1] }}
|
||||
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
|
||||
{%- endif %}
|
||||
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" + eos_token }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- elif message["role"] == "assistant" %}
|
||||
{{- " " + message["content"] + eos_token }}
|
||||
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
|
||||
{%- if message.content is defined and message.content.content is defined %}
|
||||
{%- set content = message.content.content %}
|
||||
{%- else %}
|
||||
{%- set content = message.content %}
|
||||
{%- endif %}
|
||||
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
|
||||
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
|
||||
{%- endif %}
|
||||
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
|
||||
{%- else %}
|
||||
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
94
examples/tool_chat_template_mistral_parallel.jinja
Normal file
@ -0,0 +1,94 @@
|
||||
{%- if messages[0]["role"] == "system" %}
|
||||
{%- set system_message = messages[0]["content"] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
{%- if tools is defined %}
|
||||
{%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %}
|
||||
{%- if system_message is defined %}
|
||||
{%- set system_message = parallel_tool_prompt + "\n\n" + system_message %}
|
||||
{%- else %}
|
||||
{%- set system_message = parallel_tool_prompt %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
|
||||
|
||||
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
|
||||
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
|
||||
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- for message in loop_messages %}
|
||||
{%- if message["role"] == "user" %}
|
||||
{%- if tools is not none and (message == user_messages[-1]) %}
|
||||
{{- "[AVAILABLE_TOOLS] [" }}
|
||||
{%- for tool in tools %}
|
||||
{%- set tool = tool.function %}
|
||||
{{- '{"type": "function", "function": {' }}
|
||||
{%- for key, val in tool.items() if key != "return" %}
|
||||
{%- if val is string %}
|
||||
{{- '"' + key + '": "' + val + '"' }}
|
||||
{%- else %}
|
||||
{{- '"' + key + '": ' + val|tojson }}
|
||||
{%- endif %}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "}}" }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "[/AVAILABLE_TOOLS]" }}
|
||||
{%- endif %}
|
||||
{%- if loop.last and system_message is defined %}
|
||||
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
|
||||
{%- else %}
|
||||
{{- "[INST] " + message["content"] + "[/INST]" }}
|
||||
{%- endif %}
|
||||
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
|
||||
{%- if message.tool_calls is defined %}
|
||||
{%- set tool_calls = message.tool_calls %}
|
||||
{%- else %}
|
||||
{%- set tool_calls = message.content %}
|
||||
{%- endif %}
|
||||
{{- "[TOOL_CALLS] [" }}
|
||||
{%- for tool_call in tool_calls %}
|
||||
{%- set out = tool_call.function|tojson %}
|
||||
{{- out[:-1] }}
|
||||
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
|
||||
{%- endif %}
|
||||
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" + eos_token }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- elif message["role"] == "assistant" %}
|
||||
{{- " " + message["content"] + eos_token }}
|
||||
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
|
||||
{%- if message.content is defined and message.content.content is defined %}
|
||||
{%- set content = message.content.content %}
|
||||
{%- else %}
|
||||
{%- set content = message.content %}
|
||||
{%- endif %}
|
||||
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
|
||||
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
|
||||
{%- endif %}
|
||||
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
|
||||
{%- else %}
|
||||
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
@ -20,6 +20,7 @@ lm-format-enforcer == 0.10.6
|
||||
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
|
||||
typing_extensions >= 4.10
|
||||
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
|
||||
partial-json-parser # used for parsing partial JSON outputs
|
||||
pyzmq
|
||||
msgspec
|
||||
gguf == 0.9.1
|
||||
|
0
tests/tool_use/__init__.py
Normal file
32
tests/tool_use/conftest.py
Normal file
@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
from .utils import ARGS, CONFIGS, ServerConfig
|
||||
|
||||
|
||||
# for each server config, download the model and return the config
|
||||
@pytest.fixture(scope="session", params=CONFIGS.keys())
|
||||
def server_config(request):
|
||||
config = CONFIGS[request.param]
|
||||
# download model and tokenizer using transformers
|
||||
snapshot_download(config["model"])
|
||||
yield CONFIGS[request.param]
|
||||
|
||||
|
||||
# run this for each server config
|
||||
@pytest.fixture(scope="session")
|
||||
def server(request, server_config: ServerConfig):
|
||||
model = server_config["model"]
|
||||
args_for_model = server_config["arguments"]
|
||||
with RemoteOpenAIServer(model, ARGS + args_for_model,
|
||||
max_wait_seconds=480) as server:
|
||||
yield server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server: RemoteOpenAIServer):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
143
tests/tool_use/test_chat_completions.py
Normal file
@ -0,0 +1,143 @@
|
||||
from typing import List
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from .utils import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL
|
||||
|
||||
|
||||
# test: make sure chat completions without tools provided work even when tools
|
||||
# are enabled. This makes sure tool call chat templates work, AND that the tool
|
||||
# parser stream processing doesn't change the output of the model.
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_without_tools(client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITHOUT_TOOLS,
|
||||
temperature=0,
|
||||
max_tokens=150,
|
||||
model=model_name,
|
||||
logprobs=False)
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
output_text = chat_completion.choices[0].message.content
|
||||
|
||||
# check to make sure we got text
|
||||
assert output_text is not None
|
||||
assert len(output_text) > 0
|
||||
assert stop_reason != "tool_calls"
|
||||
|
||||
# check to make sure no tool calls were returned
|
||||
assert (choice.message.tool_calls is None
|
||||
or len(choice.message.tool_calls) == 0)
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITHOUT_TOOLS,
|
||||
temperature=0,
|
||||
max_tokens=150,
|
||||
model=model_name,
|
||||
logprobs=False,
|
||||
stream=True,
|
||||
)
|
||||
chunks: List[str] = []
|
||||
finish_reason_count = 0
|
||||
role_sent: bool = False
|
||||
|
||||
# assemble streamed chunks
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# make sure the role is assistant
|
||||
if delta.role:
|
||||
assert not role_sent
|
||||
assert delta.role == 'assistant'
|
||||
role_sent = True
|
||||
|
||||
if delta.content:
|
||||
chunks.append(delta.content)
|
||||
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == choice.finish_reason
|
||||
|
||||
# make sure tool call chunks aren't being streamed
|
||||
assert not delta.tool_calls or len(delta.tool_calls) == 0
|
||||
|
||||
# make sure the role was sent, only 1 finish reason was sent, that chunks
|
||||
# were in fact sent, and that the chunks match non-streaming
|
||||
assert role_sent
|
||||
assert finish_reason_count == 1
|
||||
assert len(chunks)
|
||||
assert "".join(chunks) == output_text
|
||||
|
||||
|
||||
# test: conversation with tools enabled and provided that should not invoke
|
||||
# tools, to make sure we can still get normal chat completion responses
|
||||
# and that they won't be parsed as tools
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tools(client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITHOUT_TOOLS,
|
||||
temperature=0,
|
||||
max_tokens=150,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL],
|
||||
logprobs=False)
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
output_text = chat_completion.choices[0].message.content
|
||||
|
||||
# check to make sure we got text
|
||||
assert output_text is not None
|
||||
assert stop_reason != 'tool_calls'
|
||||
assert len(output_text) > 0
|
||||
|
||||
# check to make sure no tool calls were returned
|
||||
assert (choice.message.tool_calls is None
|
||||
or len(choice.message.tool_calls) == 0)
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITHOUT_TOOLS,
|
||||
temperature=0,
|
||||
max_tokens=150,
|
||||
model=model_name,
|
||||
logprobs=False,
|
||||
tools=[WEATHER_TOOL],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks: List[str] = []
|
||||
finish_reason_count = 0
|
||||
role_sent: bool = False
|
||||
|
||||
# assemble streamed chunks
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# make sure the role is assistant
|
||||
if delta.role:
|
||||
assert delta.role == 'assistant'
|
||||
role_sent = True
|
||||
|
||||
if delta.content:
|
||||
chunks.append(delta.content)
|
||||
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
|
||||
# make sure tool call chunks aren't being streamed
|
||||
assert not delta.tool_calls or len(delta.tool_calls) == 0
|
||||
|
||||
# make sure the role was sent, only 1 finish reason was sent, that chunks
|
||||
# were in fact sent, and that the chunks match non-streaming
|
||||
assert role_sent
|
||||
assert finish_reason_count == 1
|
||||
assert chunk.choices[0].finish_reason == stop_reason
|
||||
assert chunk.choices[0].finish_reason != 'tool_calls'
|
||||
assert len(chunks)
|
||||
assert "".join(chunks) == output_text
|
193
tests/tool_use/test_parallel_tool_calls.py
Normal file
@ -0,0 +1,193 @@
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
|
||||
WEATHER_TOOL)
|
||||
|
||||
|
||||
# test: getting the model to generate parallel tool calls (streaming/not)
|
||||
# when requested. NOTE that not all models may support this, so some exclusions
|
||||
# may be added in the future. e.g. llama 3.1 models are not designed to support
|
||||
# parallel tool calls.
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls(client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
temperature=0,
|
||||
max_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
|
||||
|
||||
# make sure 2 tool calls are present
|
||||
assert choice.message.role == "assistant"
|
||||
assert non_streamed_tool_calls is not None
|
||||
assert len(non_streamed_tool_calls) == 2
|
||||
|
||||
for tool_call in non_streamed_tool_calls:
|
||||
# make sure the tool includes a function and ID
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function is not None
|
||||
assert isinstance(tool_call.id, str)
|
||||
assert len(tool_call.id) > 16
|
||||
|
||||
# make sure the weather tool was called correctly
|
||||
assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
|
||||
assert isinstance(tool_call.function.arguments, str)
|
||||
|
||||
parsed_arguments = json.loads(tool_call.function.arguments)
|
||||
assert isinstance(parsed_arguments, Dict)
|
||||
assert isinstance(parsed_arguments.get("city"), str)
|
||||
assert isinstance(parsed_arguments.get("state"), str)
|
||||
|
||||
assert stop_reason == "tool_calls"
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
temperature=0,
|
||||
max_tokens=200,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
|
||||
role_name: Optional[str] = None
|
||||
finish_reason_count: int = 0
|
||||
|
||||
tool_call_names: List[str] = []
|
||||
tool_call_args: List[str] = []
|
||||
tool_call_idx: int = -1
|
||||
tool_call_id_count: int = 0
|
||||
|
||||
async for chunk in stream:
|
||||
|
||||
# if there's a finish reason make sure it's tools
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == 'tool_calls'
|
||||
|
||||
# if a role is being streamed make sure it wasn't already set to
|
||||
# something else
|
||||
if chunk.choices[0].delta.role:
|
||||
assert not role_name or role_name == 'assistant'
|
||||
role_name = 'assistant'
|
||||
|
||||
# if a tool call is streamed make sure there's exactly one
|
||||
# (based on the request parameters)
|
||||
streamed_tool_calls = chunk.choices[0].delta.tool_calls
|
||||
|
||||
if streamed_tool_calls and len(streamed_tool_calls) > 0:
|
||||
|
||||
# make sure only one diff is present - correct even for parallel
|
||||
assert len(streamed_tool_calls) == 1
|
||||
tool_call = streamed_tool_calls[0]
|
||||
|
||||
# if a new tool is being called, set up empty arguments
|
||||
if tool_call.index != tool_call_idx:
|
||||
tool_call_idx = tool_call.index
|
||||
tool_call_args.append("")
|
||||
|
||||
# if a tool call ID is streamed, make sure one hasn't been already
|
||||
if tool_call.id:
|
||||
tool_call_id_count += 1
|
||||
assert (isinstance(tool_call.id, str)
|
||||
and (len(tool_call.id) > 16))
|
||||
|
||||
# if parts of the function start being streamed
|
||||
if tool_call.function:
|
||||
# if the function name is defined, set it. it should be streamed
|
||||
# IN ENTIRETY, exactly one time.
|
||||
if tool_call.function.name:
|
||||
assert isinstance(tool_call.function.name, str)
|
||||
tool_call_names.append(tool_call.function.name)
|
||||
|
||||
if tool_call.function.arguments:
|
||||
# make sure they're a string and then add them to the list
|
||||
assert isinstance(tool_call.function.arguments, str)
|
||||
|
||||
tool_call_args[
|
||||
tool_call.index] += tool_call.function.arguments
|
||||
|
||||
assert finish_reason_count == 1
|
||||
assert role_name == 'assistant'
|
||||
|
||||
assert (len(non_streamed_tool_calls) == len(tool_call_names) ==
|
||||
len(tool_call_args))
|
||||
|
||||
for i in range(2):
|
||||
assert non_streamed_tool_calls[i].function.name == tool_call_names[i]
|
||||
streamed_args = json.loads(tool_call_args[i])
|
||||
non_streamed_args = json.loads(
|
||||
non_streamed_tool_calls[i].function.arguments)
|
||||
assert streamed_args == non_streamed_args
|
||||
|
||||
|
||||
# test: providing parallel tool calls back to the model to get a response
|
||||
# (streaming/not)
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
|
||||
temperature=0,
|
||||
max_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
|
||||
assert choice.finish_reason != "tool_calls" # "stop" or "length"
|
||||
assert choice.message.role == "assistant"
|
||||
assert choice.message.tool_calls is None \
|
||||
or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.content is not None
|
||||
assert "98" in choice.message.content # Dallas temp in tool response
|
||||
assert "78" in choice.message.content # Orlando temp in tool response
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
|
||||
temperature=0,
|
||||
max_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
|
||||
chunks: List[str] = []
|
||||
finish_reason_count = 0
|
||||
role_sent: bool = False
|
||||
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
if delta.role:
|
||||
assert not role_sent
|
||||
assert delta.role == "assistant"
|
||||
role_sent = True
|
||||
|
||||
if delta.content:
|
||||
chunks.append(delta.content)
|
||||
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == choice.finish_reason
|
||||
|
||||
assert not delta.tool_calls or len(delta.tool_calls) == 0
|
||||
|
||||
assert role_sent
|
||||
assert finish_reason_count == 1
|
||||
assert len(chunks)
|
||||
assert "".join(chunks) == choice.message.content
|
192
tests/tool_use/test_tool_calls.py
Normal file
@ -0,0 +1,192 @@
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE,
|
||||
SEARCH_TOOL, WEATHER_TOOL)
|
||||
|
||||
|
||||
# test: request a chat completion that should return tool calls, so we know they
|
||||
# are parsable
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_ASKING_FOR_TOOLS,
|
||||
temperature=0,
|
||||
max_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
tool_calls = chat_completion.choices[0].message.tool_calls
|
||||
|
||||
# make sure a tool call is present
|
||||
assert choice.message.role == 'assistant'
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].type == 'function'
|
||||
assert tool_calls[0].function is not None
|
||||
assert isinstance(tool_calls[0].id, str)
|
||||
assert len(tool_calls[0].id) > 16
|
||||
|
||||
# make sure the weather tool was called (classic example) with arguments
|
||||
assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
|
||||
assert tool_calls[0].function.arguments is not None
|
||||
assert isinstance(tool_calls[0].function.arguments, str)
|
||||
|
||||
# make sure the arguments parse properly
|
||||
parsed_arguments = json.loads(tool_calls[0].function.arguments)
|
||||
assert isinstance(parsed_arguments, Dict)
|
||||
assert isinstance(parsed_arguments.get("city"), str)
|
||||
assert isinstance(parsed_arguments.get("state"), str)
|
||||
assert parsed_arguments.get("city") == "Dallas"
|
||||
assert parsed_arguments.get("state") == "TX"
|
||||
|
||||
assert stop_reason == "tool_calls"
|
||||
|
||||
function_name: Optional[str] = None
|
||||
function_args_str: str = ''
|
||||
tool_call_id: Optional[str] = None
|
||||
role_name: Optional[str] = None
|
||||
finish_reason_count: int = 0
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=MESSAGES_ASKING_FOR_TOOLS,
|
||||
temperature=0,
|
||||
max_tokens=100,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
|
||||
async for chunk in stream:
|
||||
assert chunk.choices[0].index == 0
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == 'tool_calls'
|
||||
|
||||
# if a role is being streamed make sure it wasn't already set to
|
||||
# something else
|
||||
if chunk.choices[0].delta.role:
|
||||
assert not role_name or role_name == 'assistant'
|
||||
role_name = 'assistant'
|
||||
|
||||
# if a tool call is streamed make sure there's exactly one
|
||||
# (based on the request parameters)
|
||||
streamed_tool_calls = chunk.choices[0].delta.tool_calls
|
||||
|
||||
if streamed_tool_calls and len(streamed_tool_calls) > 0:
|
||||
assert len(streamed_tool_calls) == 1
|
||||
tool_call = streamed_tool_calls[0]
|
||||
|
||||
# if a tool call ID is streamed, make sure one hasn't been already
|
||||
if tool_call.id:
|
||||
assert not tool_call_id
|
||||
tool_call_id = tool_call.id
|
||||
|
||||
# if parts of the function start being streamed
|
||||
if tool_call.function:
|
||||
# if the function name is defined, set it. it should be streamed
|
||||
# IN ENTIRETY, exactly one time.
|
||||
if tool_call.function.name:
|
||||
assert function_name is None
|
||||
assert isinstance(tool_call.function.name, str)
|
||||
function_name = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
assert isinstance(tool_call.function.arguments, str)
|
||||
function_args_str += tool_call.function.arguments
|
||||
|
||||
assert finish_reason_count == 1
|
||||
assert role_name == 'assistant'
|
||||
assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16)
|
||||
|
||||
# validate the name and arguments
|
||||
assert function_name == WEATHER_TOOL["function"]["name"]
|
||||
assert function_name == tool_calls[0].function.name
|
||||
assert isinstance(function_args_str, str)
|
||||
|
||||
# validate arguments
|
||||
streamed_args = json.loads(function_args_str)
|
||||
assert isinstance(streamed_args, Dict)
|
||||
assert isinstance(streamed_args.get("city"), str)
|
||||
assert isinstance(streamed_args.get("state"), str)
|
||||
assert streamed_args.get("city") == "Dallas"
|
||||
assert streamed_args.get("state") == "TX"
|
||||
|
||||
# make sure everything matches non-streaming except for ID
|
||||
assert function_name == tool_calls[0].function.name
|
||||
assert choice.message.role == role_name
|
||||
assert choice.message.tool_calls[0].function.name == function_name
|
||||
|
||||
# compare streamed with non-streamed args Dict-wise, not string-wise
|
||||
# because character-to-character comparison might not work e.g. the tool
|
||||
# call parser adding extra spaces or something like that. we care about the
|
||||
# dicts matching not byte-wise match
|
||||
assert parsed_arguments == streamed_args
|
||||
|
||||
|
||||
# test: providing tools and results back to model to get a non-tool response
|
||||
# (streaming/not)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_with_results(client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITH_TOOL_RESPONSE,
|
||||
temperature=0,
|
||||
max_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
|
||||
assert choice.finish_reason != "tool_calls" # "stop" or "length"
|
||||
assert choice.message.role == "assistant"
|
||||
assert choice.message.tool_calls is None \
|
||||
or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.content is not None
|
||||
assert "98" in choice.message.content # the temperature from the response
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITH_TOOL_RESPONSE,
|
||||
temperature=0,
|
||||
max_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
|
||||
chunks: List[str] = []
|
||||
finish_reason_count = 0
|
||||
role_sent: bool = False
|
||||
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
if delta.role:
|
||||
assert not role_sent
|
||||
assert delta.role == "assistant"
|
||||
role_sent = True
|
||||
|
||||
if delta.content:
|
||||
chunks.append(delta.content)
|
||||
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == choice.finish_reason
|
||||
|
||||
assert not delta.tool_calls or len(delta.tool_calls) == 0
|
||||
|
||||
assert role_sent
|
||||
assert finish_reason_count == 1
|
||||
assert len(chunks)
|
||||
assert "".join(chunks) == choice.message.content
|
215
tests/tool_use/utils.py
Normal file
@ -0,0 +1,215 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from openai.types.chat import (ChatCompletionMessageParam,
|
||||
ChatCompletionToolParam)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from tests.utils import VLLM_PATH
|
||||
|
||||
|
||||
class ServerConfig(TypedDict):
|
||||
model: str
|
||||
arguments: List[str]
|
||||
|
||||
|
||||
# universal args for all models go here. also good if you need to test locally
|
||||
# and change type or KV cache quantization or something.
|
||||
ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"]
|
||||
|
||||
CONFIGS: Dict[str, ServerConfig] = {
|
||||
"hermes": {
|
||||
"model":
|
||||
"NousResearch/Hermes-2-Pro-Llama-3-8B",
|
||||
"arguments": [
|
||||
"--tool-call-parser", "hermes", "--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja")
|
||||
]
|
||||
},
|
||||
"mistral": {
|
||||
"model":
|
||||
"mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"arguments": [
|
||||
"--tool-call-parser", "mistral", "--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"),
|
||||
"--ignore-patterns=\"consolidated.safetensors\""
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
WEATHER_TOOL: ChatCompletionToolParam = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to find the weather for, "
|
||||
"e.g. 'San Francisco'"
|
||||
},
|
||||
"state": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"the two-letter abbreviation for the state "
|
||||
"that the city is in, e.g. 'CA' which would "
|
||||
"mean 'California'"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SEARCH_TOOL: ChatCompletionToolParam = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name":
|
||||
"web_search",
|
||||
"description":
|
||||
"Search the internet and get a summary of the top "
|
||||
"10 webpages. Should only be used if you don't know "
|
||||
"the answer to a user query, and the results are likely"
|
||||
"to be able to be found with a web search",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"search_term": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The term to use in the search. This should"
|
||||
"ideally be keywords to search for, not a"
|
||||
"natural-language question"
|
||||
}
|
||||
},
|
||||
"required": ["search_term"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{
|
||||
"role":
|
||||
"system",
|
||||
"content":
|
||||
"You are a helpful assistant with access to tools. If a tool"
|
||||
" that you have would be helpful to answer a user query, "
|
||||
"call the tool. Otherwise, answer the user's query directly "
|
||||
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
|
||||
"to the user's question - just respond to it normally."
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Hi! How are you?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content":
|
||||
"I'm doing great! How can I assist you?"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me a joke please?"
|
||||
}]
|
||||
|
||||
MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"What is the weather in Dallas, Texas in Fahrenheit?"
|
||||
}]
|
||||
|
||||
MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"What is the weather in Dallas, Texas in Fahrenheit?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"tool_calls": [{
|
||||
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name":
|
||||
WEATHER_TOOL["function"]["name"],
|
||||
"arguments":
|
||||
'{"city": "Dallas", "state": "TX", '
|
||||
'"unit": "fahrenheit"}'
|
||||
}
|
||||
}]
|
||||
}, {
|
||||
"role":
|
||||
"tool",
|
||||
"tool_call_id":
|
||||
"chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"content":
|
||||
"The weather in Dallas is 98 degrees fahrenheit, with partly"
|
||||
"cloudy skies and a low chance of rain."
|
||||
}]
|
||||
|
||||
MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"What is the weather in Dallas, Texas and Orlando, Florida in "
|
||||
"Fahrenheit?"
|
||||
}]
|
||||
|
||||
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"What is the weather in Dallas, Texas and Orlando, Florida in "
|
||||
"Fahrenheit?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"tool_calls": [{
|
||||
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name":
|
||||
WEATHER_TOOL["function"]["name"],
|
||||
"arguments":
|
||||
'{"city": "Dallas", "state": "TX", '
|
||||
'"unit": "fahrenheit"}'
|
||||
}
|
||||
}, {
|
||||
"id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name":
|
||||
WEATHER_TOOL["function"]["name"],
|
||||
"arguments":
|
||||
'{"city": "Orlando", "state": "Fl", '
|
||||
'"unit": "fahrenheit"}'
|
||||
}
|
||||
}]
|
||||
}, {
|
||||
"role":
|
||||
"tool",
|
||||
"tool_call_id":
|
||||
"chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"content":
|
||||
"The weather in Dallas TX is 98 degrees fahrenheit with mostly "
|
||||
"cloudy skies and a chance of rain in the evening."
|
||||
}, {
|
||||
"role":
|
||||
"tool",
|
||||
"tool_call_id":
|
||||
"chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
|
||||
"content":
|
||||
"The weather in Orlando FL is 78 degrees fahrenheit with clear"
|
||||
"skies."
|
||||
}]
|
@ -1,23 +1,28 @@
|
||||
import asyncio
|
||||
import codecs
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from functools import lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
|
||||
Mapping, Optional, Tuple, TypeVar, Union)
|
||||
Mapping, Optional, Tuple, TypeVar, Union, cast)
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from openai.types.chat import ChatCompletionContentPartImageParam
|
||||
from openai.types.chat import (ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionContentPartImageParam)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
|
||||
from openai.types.chat import ChatCompletionContentPartTextParam
|
||||
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
|
||||
ChatCompletionContentPartTextParam)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
|
||||
from openai.types.chat import (ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionToolMessageParam)
|
||||
# yapf: enable
|
||||
# pydantic needs the TypedDict from typing_extensions
|
||||
from pydantic import ConfigDict, TypeAdapter
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import Required, TypeAlias, TypedDict
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
@ -54,7 +59,8 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
|
||||
|
||||
ChatCompletionContentPartParam: TypeAlias = Union[
|
||||
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
|
||||
CustomChatCompletionContentPartParam, ]
|
||||
ChatCompletionContentPartRefusalParam,
|
||||
CustomChatCompletionContentPartParam]
|
||||
|
||||
|
||||
class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||
@ -72,15 +78,33 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||
same role.
|
||||
"""
|
||||
|
||||
tool_call_id: Optional[str]
|
||||
"""Tool call that this message is responding to."""
|
||||
|
||||
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
|
||||
"""The tool calls generated by the model, such as function calls."""
|
||||
|
||||
|
||||
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
|
||||
CustomChatCompletionMessageParam]
|
||||
|
||||
|
||||
# TODO: Make fields ReadOnly once mypy supports it
|
||||
class ConversationMessage(TypedDict):
|
||||
role: str
|
||||
content: str
|
||||
class ConversationMessage(TypedDict, total=False):
|
||||
role: Required[str]
|
||||
"""The role of the message's author."""
|
||||
|
||||
content: Optional[str]
|
||||
"""The contents of the message"""
|
||||
|
||||
tool_call_id: Optional[str]
|
||||
"""Tool call that this message is responding to."""
|
||||
|
||||
name: Optional[str]
|
||||
"""The name of the function to call"""
|
||||
|
||||
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
|
||||
"""The tool calls generated by the model, such as function calls."""
|
||||
|
||||
|
||||
ModalityStr = Literal["image", "audio"]
|
||||
@ -319,9 +343,11 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
|
||||
return "\n".join(missing_placeholders + [text_prompt])
|
||||
|
||||
|
||||
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
|
||||
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam)
|
||||
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
|
||||
# No need to validate using Pydantic again
|
||||
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
|
||||
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
|
||||
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
|
||||
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||
|
||||
|
||||
def _parse_chat_message_content_parts(
|
||||
@ -336,10 +362,10 @@ def _parse_chat_message_content_parts(
|
||||
for part in parts:
|
||||
part_type = part["type"]
|
||||
if part_type == "text":
|
||||
text = _TextParser.validate_python(part)["text"]
|
||||
text = _TextParser(part)["text"]
|
||||
texts.append(text)
|
||||
elif part_type == "image_url":
|
||||
image_url = _ImageParser.validate_python(part)["image_url"]
|
||||
image_url = _ImageParser(part)["image_url"]
|
||||
|
||||
if image_url.get("detail", "auto") != "auto":
|
||||
logger.warning(
|
||||
@ -348,7 +374,7 @@ def _parse_chat_message_content_parts(
|
||||
|
||||
mm_parser.parse_image(image_url["url"])
|
||||
elif part_type == "audio_url":
|
||||
audio_url = _AudioParser.validate_python(part)["audio_url"]
|
||||
audio_url = _AudioParser(part)["audio_url"]
|
||||
|
||||
mm_parser.parse_audio(audio_url["url"])
|
||||
else:
|
||||
@ -363,6 +389,11 @@ def _parse_chat_message_content_parts(
|
||||
return [ConversationMessage(role=role, content=text_prompt)]
|
||||
|
||||
|
||||
# No need to validate using Pydantic again
|
||||
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
|
||||
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
|
||||
|
||||
|
||||
def _parse_chat_message_content(
|
||||
message: ChatCompletionMessageParam,
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
@ -371,16 +402,34 @@ def _parse_chat_message_content(
|
||||
content = message.get("content")
|
||||
|
||||
if content is None:
|
||||
return []
|
||||
if isinstance(content, str):
|
||||
return [ConversationMessage(role=role, content=content)]
|
||||
content = []
|
||||
elif isinstance(content, str):
|
||||
content = [
|
||||
ChatCompletionContentPartTextParam(type="text", text=content)
|
||||
]
|
||||
|
||||
return _parse_chat_message_content_parts(
|
||||
result = _parse_chat_message_content_parts(
|
||||
role,
|
||||
content, # type: ignore
|
||||
mm_tracker,
|
||||
)
|
||||
|
||||
for result_msg in result:
|
||||
if role == 'assistant':
|
||||
parsed_msg = _AssistantParser(message)
|
||||
|
||||
if "tool_calls" in parsed_msg:
|
||||
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
|
||||
elif role == "tool":
|
||||
parsed_msg = _ToolParser(message)
|
||||
if "tool_call_id" in parsed_msg:
|
||||
result_msg["tool_call_id"] = parsed_msg["tool_call_id"]
|
||||
|
||||
if "name" in message and isinstance(message["name"], str):
|
||||
result_msg["name"] = message["name"]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parse_chat_messages(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
@ -428,6 +477,20 @@ def apply_chat_template(
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one.")
|
||||
|
||||
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts, not JSON strings -
# this is how tool-use chat templates will expect them moving forward.
# So, for messages that have tool_calls, parse the string (which we get
# from the OpenAI format) into a dict
|
||||
for message in conversation:
|
||||
if (message["role"] == "assistant" and "tool_calls" in message
|
||||
and isinstance(message["tool_calls"], list)):
|
||||
|
||||
for i in range(len(message["tool_calls"])):
|
||||
args: str = message["tool_calls"][i]["function"]["arguments"]
|
||||
parsed_args: Dict = json.loads(args)
|
||||
message["tool_calls"][i]["function"]["arguments"] = parsed_args
|
||||
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
conversation=conversation,
|
||||
chat_template=chat_template,
|
||||
|
@ -233,7 +233,7 @@ def mount_metrics(app: FastAPI):
|
||||
metrics_route = Mount("/metrics", make_asgi_app())
|
||||
|
||||
# Workaround for 307 Redirect for /metrics
|
||||
metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
|
||||
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
|
||||
app.routes.append(metrics_route)
|
||||
|
||||
|
||||
@ -283,11 +283,14 @@ async def show_version():
|
||||
@router.post("/v1/chat/completions")
|
||||
async def create_chat_completion(request: ChatCompletionRequest,
|
||||
raw_request: Request):
|
||||
|
||||
generator = await openai_serving_chat.create_chat_completion(
|
||||
request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
|
||||
elif isinstance(generator, ChatCompletionResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
@ -422,7 +425,8 @@ async def init_app(
|
||||
request_logger=request_logger,
|
||||
chat_template=args.chat_template,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
)
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser)
|
||||
openai_serving_completion = OpenAIServingCompletion(
|
||||
async_engine_client,
|
||||
model_config,
|
||||
|
@ -163,6 +163,24 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
help="If specified, will run the OpenAI frontend server in the same "
|
||||
"process as the model serving engine.")
|
||||
|
||||
parser.add_argument(
|
||||
"--enable-auto-tool-choice",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
|
||||
"Enable auto tool choice for supported models. Use --tool-call-parser"
|
||||
"to specify which parser to use")
|
||||
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
type=str,
|
||||
choices=["mistral", "hermes"],
|
||||
default=None,
|
||||
help=
|
||||
"Select the tool call parser depending on the model that you're using."
|
||||
" This is used to parse the model-generated tool call into OpenAI API "
|
||||
"format. Required for --enable-auto-tool-choice.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument('--max-log-len',
|
||||
|
@ -5,8 +5,9 @@ from argparse import Namespace
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from openai.types.chat import ChatCompletionContentPartParam
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Annotated
|
||||
from typing_extensions import Annotated, Required, TypedDict
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.logits_processors import get_logits_processors
|
||||
@ -35,6 +36,26 @@ assert _LONG_INFO.min == _MOCK_LONG_INFO.min
|
||||
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
|
||||
|
||||
|
||||
class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||
"""Enables custom roles in the Chat Completion API."""
|
||||
role: Required[str]
|
||||
"""The role of the message's author."""
|
||||
|
||||
content: Union[str, List[ChatCompletionContentPartParam]]
|
||||
"""The contents of the message."""
|
||||
|
||||
name: str
|
||||
"""An optional name for the participant.
|
||||
|
||||
Provides the model information to differentiate between participants of the
|
||||
same role.
|
||||
"""
|
||||
|
||||
tool_call_id: Optional[str]
|
||||
|
||||
tool_calls: Optional[List[dict]]
|
||||
|
||||
|
||||
class OpenAIBaseModel(BaseModel):
|
||||
# OpenAI API does not allow extra fields
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
@ -145,8 +166,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
temperature: Optional[float] = 0.7
|
||||
top_p: Optional[float] = 1.0
|
||||
tools: Optional[List[ChatCompletionToolsParam]] = None
|
||||
tool_choice: Optional[Union[Literal["none"],
|
||||
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
|
||||
ChatCompletionNamedToolChoiceParam]] = "none"
|
||||
|
||||
# NOTE this will be ignored by VLLM -- the model determines the behavior
|
||||
parallel_tool_calls: Optional[bool] = False
|
||||
user: Optional[str] = None
|
||||
|
||||
# doc: begin-chat-completion-sampling-params
|
||||
@ -328,6 +352,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_guided_decoding_count(cls, data):
|
||||
if isinstance(data, ValueError):
|
||||
raise data
|
||||
|
||||
guide_count = sum([
|
||||
"guided_json" in data and data["guided_json"] is not None,
|
||||
"guided_regex" in data and data["guided_regex"] is not None,
|
||||
@ -339,21 +366,61 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
"You can only use one kind of guided decoding "
|
||||
"('guided_json', 'guided_regex' or 'guided_choice').")
|
||||
# you can only either use guided decoding or tools, not both
|
||||
if guide_count > 1 and "tool_choice" in data and data[
|
||||
"tool_choice"] != "none":
|
||||
if guide_count > 1 and data.get("tool_choice",
|
||||
"none") not in ("none", "auto"):
|
||||
raise ValueError(
|
||||
"You can only either use guided decoding or tools, not both.")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_tool_choice(cls, data):
|
||||
if "tool_choice" in data and data["tool_choice"] != "none":
|
||||
if not isinstance(data["tool_choice"], dict):
|
||||
raise ValueError("Currently only named tools are supported.")
|
||||
def check_tool_usage(cls, data):
|
||||
|
||||
# if "tool_choice" is not specified but tools are provided,
|
||||
# default to "auto" tool_choice
|
||||
if "tool_choice" not in data and "tools" in data:
|
||||
data["tool_choice"] = "auto"
|
||||
|
||||
# if "tool_choice" is specified -- validation
|
||||
if "tool_choice" in data:
|
||||
|
||||
# ensure that if "tool choice" is specified, tools are present
|
||||
if "tools" not in data or data["tools"] is None:
|
||||
raise ValueError(
|
||||
"When using `tool_choice`, `tools` must be set.")
|
||||
|
||||
# make sure that tool choice is either a named tool
|
||||
# OR that it's set to "auto"
|
||||
if data["tool_choice"] != "auto" and not isinstance(
|
||||
data["tool_choice"], dict):
|
||||
raise ValueError(
|
||||
"`tool_choice` must either be a named tool or \"auto\". "
|
||||
"`tool_choice=\"none\" is not supported.")
|
||||
|
||||
# ensure that if "tool_choice" is specified as an object,
|
||||
# it matches a valid tool
|
||||
if isinstance(data["tool_choice"], dict):
|
||||
valid_tool = False
|
||||
specified_function = data["tool_choice"]["function"]
|
||||
if not specified_function:
|
||||
raise ValueError(
|
||||
"Incorrectly formatted `tool_choice`. Should be like "
|
||||
"`{\"type\": \"function\","
|
||||
" \"function\": {\"name\": \"my_function\"}}`")
|
||||
specified_function_name = specified_function["name"]
|
||||
if not specified_function_name:
|
||||
raise ValueError(
|
||||
"Incorrectly formatted `tool_choice`. Should be like "
|
||||
"`{\"type\": \"function\", "
|
||||
"\"function\": {\"name\": \"my_function\"}}`")
|
||||
for tool in data["tools"]:
|
||||
if tool["function"]["name"] == specified_function_name:
|
||||
valid_tool = True
|
||||
break
|
||||
if not valid_tool:
|
||||
raise ValueError(
|
||||
"The tool specified in `tool_choice` does not match any"
|
||||
" of the specified `tools`")
|
||||
return data
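For orientation, a small sketch (not from the diff) of request fragments this validator would accept or reject; the tool definition is hypothetical.

```python
# Hypothetical tool definition used in the request fragments below.
weather_tool = {
    "type": "function",
    "function": {"name": "get_current_weather", "parameters": {"type": "object"}},
}

ok_named = {  # accepted: named tool_choice matching a provided tool
    "tools": [weather_tool],
    "tool_choice": {"type": "function", "function": {"name": "get_current_weather"}},
}
ok_auto = {"tools": [weather_tool]}  # accepted: tool_choice defaults to "auto"

bad_no_tools = {"tool_choice": "auto"}  # rejected: tool_choice without tools
bad_none = {  # rejected: explicit tool_choice="none" is not supported here
    "tools": [weather_tool],
    "tool_choice": "none",
}
bad_unknown = {  # rejected: named tool does not match any provided tool
    "tools": [weather_tool],
    "tool_choice": {"type": "function", "function": {"name": "search_web"}},
}
```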
|
||||
|
||||
|
||||
@ -413,7 +480,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
)
|
||||
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
|
||||
default=None,
|
||||
description=("If specified, the output will follow the JSON schema."),
|
||||
description="If specified, the output will follow the JSON schema.",
|
||||
)
|
||||
guided_regex: Optional[str] = Field(
|
||||
default=None,
|
||||
@ -633,9 +700,41 @@ class ToolCall(OpenAIBaseModel):
|
||||
function: FunctionCall
|
||||
|
||||
|
||||
class DeltaFunctionCall(BaseModel):
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
# a tool call delta where everything is optional
|
||||
class DeltaToolCall(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
|
||||
type: Literal["function"] = "function"
|
||||
index: int
|
||||
function: Optional[DeltaFunctionCall] = None
|
||||
|
||||
|
||||
# the initial delta that gets sent once a new tool call is started;
|
||||
class InitialDeltaToolCall(DeltaToolCall):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
|
||||
type: Literal["function"] = "function"
|
||||
index: int
|
||||
|
||||
|
||||
class ExtractedToolCallInformation(BaseModel):
|
||||
# indicate if tools were called
|
||||
tools_called: bool
|
||||
|
||||
# extracted tool calls
|
||||
tool_calls: List[ToolCall]
|
||||
|
||||
# content - per the OpenAI spec, content AND tool calls are only rarely
# returned together, but some models will do this intentionally
|
||||
content: Optional[str] = None
|
||||
|
||||
|
||||
class ChatMessage(OpenAIBaseModel):
|
||||
role: str
|
||||
content: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
@ -657,7 +756,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
logprobs: Optional[ChatCompletionLogProbs] = None
|
||||
finish_reason: Optional[str] = None
|
||||
# per OpenAI spec this is the default
|
||||
finish_reason: Optional[str] = "stop"
|
||||
# not part of the OpenAI spec but included in vLLM for legacy reasons
|
||||
stop_reason: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
@ -674,7 +775,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
|
||||
class DeltaMessage(OpenAIBaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
tool_calls: List[DeltaToolCall] = Field(default_factory=list)
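For orientation, a hedged sketch of the sequence of `delta` payloads a client might observe while one auto tool call streams with these types; the values are illustrative, not captured output.

```python
# Illustrative (not captured) "delta" payloads for a single streamed tool call.
streamed_deltas = [
    {"role": "assistant"},                                   # first chunk: role only
    {"tool_calls": [{"index": 0, "id": "chatcmpl-tool-abc",  # InitialDeltaToolCall
                     "type": "function"}]},
    {"tool_calls": [{"index": 0,                             # name arrives once
                     "function": {"name": "get_current_weather"}}]},
    {"tool_calls": [{"index": 0,                             # arguments arrive as
                     "function": {"arguments": "{\"city\": \""}}]},  # string pieces
    {"tool_calls": [{"index": 0,
                     "function": {"arguments": "Dallas\"}"}}]},
]

# Concatenating the streamed "arguments" fragments yields the full JSON string.
full_args = "".join(
    d["tool_calls"][0]["function"]["arguments"]
    for d in streamed_deltas
    if d.get("tool_calls")
    and d["tool_calls"][0].get("function", {}).get("arguments"))
assert full_args == '{"city": "Dallas"}'
```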
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
|
||||
|
@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional
|
||||
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
|
||||
Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Union
|
||||
|
||||
@ -18,15 +20,18 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||
FunctionCall, ToolCall, UsageInfo)
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath,
|
||||
TextTokensPrompt)
|
||||
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
|
||||
MistralToolParser,
|
||||
ToolParser)
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
@ -38,19 +43,19 @@ logger = init_logger(__name__)
|
||||
|
||||
class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
async_engine_client: AsyncEngineClient,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
response_role: str,
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
def __init__(self,
|
||||
async_engine_client: AsyncEngineClient,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
response_role: str,
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_auto_tools: bool = False,
|
||||
tool_parser: Optional[str] = None):
|
||||
super().__init__(async_engine_client=async_engine_client,
|
||||
model_config=model_config,
|
||||
served_model_names=served_model_names,
|
||||
@ -60,10 +65,27 @@ class OpenAIServingChat(OpenAIServing):
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
|
||||
self.response_role = response_role
|
||||
|
||||
# If this is None we use the tokenizer's default chat template
|
||||
self.use_tool_use_model_template = False
|
||||
self.chat_template = load_chat_template(chat_template)
|
||||
|
||||
# set up tool use
|
||||
self.enable_auto_tools: bool = enable_auto_tools
|
||||
if self.enable_auto_tools:
|
||||
logger.info(
|
||||
"\"auto\" tool choice has been enabled please note that while"
|
||||
" the parallel_tool_calls client option is preset for "
|
||||
"compatibility reasons, it will be ignored.")
|
||||
|
||||
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
|
||||
if self.enable_auto_tools:
|
||||
if tool_parser == "mistral":
|
||||
self.tool_parser = MistralToolParser
|
||||
elif tool_parser == "hermes":
|
||||
self.tool_parser = Hermes2ProToolParser
|
||||
else:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||
"--tool-call-parser")
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
@ -76,11 +98,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
for the API specification. This API mimics the OpenAI
|
||||
ChatCompletion API.
|
||||
|
||||
NOTE: Currently we do not support the following feature:
|
||||
- function_call (Users should implement this by themselves)
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
logger.error("Error with model %s", error_check_ret)
|
||||
return error_check_ret
|
||||
|
||||
try:
|
||||
@ -119,6 +140,20 @@ class OpenAIServingChat(OpenAIServing):
|
||||
logger.error("Error in loading multi-modal data: %s", e)
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# validation for OpenAI tools
|
||||
# tool_choice = "required" is not supported
|
||||
if request.tool_choice == "required":
|
||||
return self.create_error_response(
|
||||
"tool_choice = \"required\" is not supported!")
|
||||
|
||||
# "auto" tools requires --enable-auto-tool-choice
|
||||
# and --tool-call-parser
|
||||
if request.tool_choice == "auto" and not (
|
||||
self.enable_auto_tools and self.tool_parser is not None):
|
||||
return self.create_error_response(
|
||||
"\"auto\" tool choice requires "
|
||||
"--enable-auto-tool-choice and --tool-call-parser to be set")
|
||||
|
||||
request_id = f"chat-{random_uuid()}"
|
||||
try:
|
||||
guided_decode_logits_processor = (
|
||||
@ -187,6 +222,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if request.stream:
|
||||
return self.chat_completion_stream_generator(
|
||||
request, result_generator, request_id, conversation, tokenizer)
|
||||
|
||||
try:
|
||||
return await self.chat_completion_full_generator(
|
||||
request, result_generator, request_id, conversation, tokenizer)
|
||||
@ -219,6 +255,9 @@ class OpenAIServingChat(OpenAIServing):
|
||||
previous_num_tokens = [0] * num_choices
|
||||
finish_reason_sent = [False] * num_choices
|
||||
|
||||
tool_parser: Optional[ToolParser] = self.tool_parser(
|
||||
tokenizer) if self.tool_parser else None
|
||||
|
||||
try:
|
||||
async for res in result_generator:
|
||||
# We need to do it here, because if there are exceptions in
|
||||
@ -228,6 +267,9 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# Send first response for each request.n (index) with
|
||||
# the role
|
||||
role = self.get_chat_request_role(request)
|
||||
|
||||
# NOTE num_choices defaults to 1 so this usually executes
|
||||
# once per request
|
||||
for i in range(num_choices):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
@ -240,14 +282,18 @@ class OpenAIServingChat(OpenAIServing):
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
|
||||
# if usage should be included
|
||||
if (request.stream_options
|
||||
and request.stream_options.include_usage):
|
||||
if (request.stream_options.continuous_usage_stats):
|
||||
# if continuous usage stats are requested, add it
|
||||
if request.stream_options.continuous_usage_stats:
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
usage = UsageInfo(prompt_tokens=prompt_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=prompt_tokens)
|
||||
chunk.usage = usage
|
||||
# otherwise don't
|
||||
else:
|
||||
chunk.usage = None
|
||||
|
||||
@ -257,7 +303,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# Send response to echo the input portion of the
|
||||
# last message
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
last_msg_content: Optional[str] = ""
|
||||
if conversation and conversation[-1].get(
|
||||
"content") and conversation[-1].get(
|
||||
"role") == role:
|
||||
@ -298,6 +344,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
first_iteration = False
|
||||
|
||||
for output in res.outputs:
|
||||
|
||||
i = output.index
|
||||
|
||||
if finish_reason_sent[i]:
|
||||
@ -320,19 +367,49 @@ class OpenAIServingChat(OpenAIServing):
|
||||
logprobs = None
|
||||
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
delta_message: Optional[DeltaMessage] = None
|
||||
|
||||
# handle streaming deltas for tools with named tool_choice
|
||||
if (request.tool_choice and type(request.tool_choice) is
|
||||
ChatCompletionNamedToolChoiceParam):
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(function=DeltaFunctionCall(
|
||||
name=request.tool_choice.function.name,
|
||||
arguments=delta_text),
|
||||
index=i)
|
||||
])
|
||||
|
||||
# handle streaming deltas for tools with "auto" tool choice
|
||||
elif (self._should_stream_with_auto_tool_parsing(request)
|
||||
and tool_parser):
|
||||
delta_message = (
|
||||
tool_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_texts[i],
|
||||
current_text=output.text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids= \
|
||||
output.token_ids[
|
||||
:-1 * len(delta_token_ids)
|
||||
],
|
||||
current_token_ids=output.token_ids,
|
||||
delta_token_ids=delta_token_ids
|
||||
)
|
||||
)
|
||||
|
||||
# handle streaming just a content delta
|
||||
else:
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
|
||||
# set the previous values for the next iteration
|
||||
previous_texts[i] = output.text
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
|
||||
if request.tool_choice and type(
|
||||
request.tool_choice
|
||||
) is ChatCompletionNamedToolChoiceParam:
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
ToolCall(function=FunctionCall(
|
||||
name=request.tool_choice.function.name,
|
||||
arguments=delta_text))
|
||||
])
|
||||
else:
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
# if the message delta is None (e.g. because it was a
# "control token" for tool calls, or the parser otherwise
# wasn't ready to send a token), then
# get the next token without streaming a chunk
|
||||
if delta_message is None:
|
||||
continue
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Send token-by-token response for each request.n
|
||||
@ -348,6 +425,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
|
||||
# handle usage stats if requested & if continuous
|
||||
if (request.stream_options
|
||||
and request.stream_options.include_usage):
|
||||
if (request.stream_options.continuous_usage_stats):
|
||||
@ -365,14 +444,55 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# if the model is finished generating
|
||||
else:
|
||||
# check to make sure we haven't "forgotten" to stream
|
||||
# any tokens that were generated but previously
|
||||
# matched by partial json parsing
|
||||
# only happens if we are NOT using guided decoding
|
||||
if tool_parser:
|
||||
index = len(
|
||||
tool_parser.prev_tool_call_arr) - 1 if len(
|
||||
tool_parser.prev_tool_call_arr) > 0 else 0
|
||||
else:
|
||||
index = 0
|
||||
|
||||
if self._should_check_for_unstreamed_tool_arg_tokens(
|
||||
delta_message, output) and tool_parser:
|
||||
# get the expected call based on partial JSON
|
||||
# parsing which "autocompletes" the JSON
|
||||
expected_call = json.dumps(
|
||||
tool_parser.prev_tool_call_arr[index].get(
|
||||
"arguments", {}))
|
||||
|
||||
# get what we've streamed so far for arguments
# for the current tool
|
||||
actual_call = tool_parser.streamed_args_for_tool[
|
||||
index]
|
||||
|
||||
# check to see if there's anything left to stream
|
||||
remaining_call = expected_call.replace(
|
||||
actual_call, "", 1)
|
||||
|
||||
# set that as a delta message
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=remaining_call).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
|
||||
# Send the finish response for each request.n only once
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=delta_message,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
finish_reason=output.finish_reason
|
||||
if not (tool_parser
|
||||
and len(tool_parser.prev_tool_call_arr))
|
||||
else "tool_calls",
|
||||
stop_reason=output.stop_reason)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
@ -398,6 +518,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
yield f"data: {data}\n\n"
|
||||
finish_reason_sent[i] = True
|
||||
|
||||
# once the final token is handled, if stream_options.include_usage
|
||||
# is sent, send the usage
|
||||
if (request.stream_options
|
||||
and request.stream_options.include_usage):
|
||||
final_usage = UsageInfo(
|
||||
@ -419,6 +541,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.error("error in chat completion stream generator: %s", e)
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
@ -463,8 +586,21 @@ class OpenAIServingChat(OpenAIServing):
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
if request.tool_choice and type(
|
||||
# by default, tools are not used.
|
||||
tools_called = False
|
||||
|
||||
# if auto tools are not enabled, and a named tool choice using
|
||||
# outlines is not being used
|
||||
if not (self.enable_auto_tools
|
||||
or not self.tool_parser) and not isinstance(
|
||||
request.tool_choice,
|
||||
ChatCompletionNamedToolChoiceParam):
|
||||
message = ChatMessage(role=role, content=output.text)
|
||||
|
||||
# if the request uses tools and specified a tool choice
|
||||
elif request.tool_choice and type(
|
||||
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
|
||||
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
content="",
|
||||
@ -473,14 +609,47 @@ class OpenAIServingChat(OpenAIServing):
|
||||
name=request.tool_choice.function.name,
|
||||
arguments=output.text))
|
||||
])
|
||||
tools_called = True
|
||||
|
||||
# if the request doesn't use tool choice
|
||||
# OR specifies to not use a tool
|
||||
elif not request.tool_choice or request.tool_choice == "none":
|
||||
|
||||
message = ChatMessage(role=role, content=output.text)
|
||||
|
||||
# handle when there are tools and tool choice is auto
|
||||
elif request.tools and (
|
||||
request.tool_choice == "auto"
|
||||
or request.tool_choice is None) and self.enable_auto_tools \
|
||||
and self.tool_parser:
|
||||
|
||||
tool_parser = self.tool_parser(tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(output.text)
|
||||
tools_called = tool_call_info.tools_called
|
||||
if tool_call_info.tools_called:
|
||||
message = ChatMessage(role=role,
|
||||
content=tool_call_info.content,
|
||||
tool_calls=tool_call_info.tool_calls)
|
||||
|
||||
else:
|
||||
# FOR NOW make it a chat message; we may need to detect
# the type and construct it differently later.
|
||||
message = ChatMessage(role=role, content=output.text)
|
||||
|
||||
# undetermined case that is still important to handle
|
||||
else:
|
||||
logger.error(
|
||||
"Error in chat_completion_full_generator - cannot determine"
|
||||
" if tools should be extracted. Returning a standard chat "
|
||||
"completion.")
|
||||
message = ChatMessage(role=role, content=output.text)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=message,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
finish_reason="tool_calls" if tools_called else
|
||||
output.finish_reason if output.finish_reason else "stop",
|
||||
stop_reason=output.stop_reason)
|
||||
choices.append(choice_data)
|
||||
|
||||
@ -488,10 +657,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
last_msg_content = ""
|
||||
if conversation and conversation[-1].get(
|
||||
"content") and conversation[-1].get("role") == role:
|
||||
last_msg_content = conversation[-1]["content"]
|
||||
last_msg_content = conversation[-1]["content"] or ""
|
||||
|
||||
for choice in choices:
|
||||
full_message = last_msg_content + choice.message.content
|
||||
full_message = last_msg_content + (choice.message.content
|
||||
or "")
|
||||
choice.message.content = full_message
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
@ -574,3 +744,38 @@ class OpenAIServingChat(OpenAIServing):
|
||||
))
|
||||
|
||||
return ChatCompletionLogProbs(content=logprobs_content)
|
||||
|
||||
def _should_stream_with_auto_tool_parsing(self,
|
||||
request: ChatCompletionRequest):
|
||||
"""
|
||||
Utility function to check if streamed tokens should go through the tool
|
||||
call parser that was configured.
|
||||
|
||||
We only want to do this IF user-provided tools are set, a tool parser
|
||||
is configured, "auto" tool choice is enabled, and the request's tool
|
||||
choice field indicates that "auto" tool choice should be used.
|
||||
"""
|
||||
return (request.tools and self.tool_parser and self.enable_auto_tools
|
||||
and request.tool_choice in ['auto', None])
|
||||
|
||||
def _should_check_for_unstreamed_tool_arg_tokens(
|
||||
self,
|
||||
delta_message: Optional[DeltaMessage],
|
||||
output: CompletionOutput,
|
||||
) -> bool:
|
||||
"""
|
||||
Check to see if we should check for unstreamed tool arguments tokens.
|
||||
This is only applicable when auto tool parsing is enabled, the delta
|
||||
is a tool call with arguments.
|
||||
"""
|
||||
|
||||
# yapf: disable
|
||||
return bool(
|
||||
# if there is a delta message that includes tool calls which
|
||||
# include a function that has arguments
|
||||
self.enable_auto_tools and self.tool_parser and delta_message
|
||||
and delta_message.tool_calls and delta_message.tool_calls[0]
|
||||
and delta_message.tool_calls[0].function
|
||||
and delta_message.tool_calls[0].function.arguments is not None
|
||||
and output.finish_reason is not None
|
||||
)
|
||||
|
@ -43,7 +43,11 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
request_logger=request_logger)
|
||||
|
||||
# If this is None we use the tokenizer's default chat template
|
||||
self.chat_template = load_chat_template(chat_template)
|
||||
# the list of commonly-used chat template names for HF named templates
|
||||
hf_chat_templates: List[str] = ['default', 'tool_use']
|
||||
self.chat_template = chat_template \
|
||||
if chat_template in hf_chat_templates \
|
||||
else load_chat_template(chat_template)
|
||||
|
||||
async def create_tokenize(
|
||||
self,
|
||||
|
5
vllm/entrypoints/openai/tool_parsers/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from .abstract_tool_parser import ToolParser
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
|
||||
__all__ = ["ToolParser", "Hermes2ProToolParser", "MistralToolParser"]
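A hedged sketch of how the exported parsers map onto the `--tool-call-parser` choices (the serving code uses a simple if/elif; the dict below is just an illustration):

```python
from typing import Dict, Type

from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
                                                  MistralToolParser,
                                                  ToolParser)

# Illustrative dispatch from a --tool-call-parser value to a parser class.
TOOL_PARSERS: Dict[str, Type[ToolParser]] = {
    "hermes": Hermes2ProToolParser,
    "mistral": MistralToolParser,
}
```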
|
58
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
Normal file
@ -0,0 +1,58 @@
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaMessage,
|
||||
ExtractedToolCallInformation)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ToolParser:
|
||||
"""
|
||||
Abstract ToolParser class that should not be used directly. Provided
|
||||
properties and methods should be used in
|
||||
derived classes.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
# the index of the tool call that is currently being parsed
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.current_tool_initial_sent: bool = False
|
||||
self.streamed_args_for_tool: List[str] = []
|
||||
|
||||
self.model_tokenizer = tokenizer
|
||||
|
||||
def extract_tool_calls(self,
|
||||
model_output: str) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Method that should be implemented for extracting tool calls from
a complete model-generated string.
Used for non-streaming responses, where we have the entire model response
available before sending to the client.
Effectively stateless, so it could be a static method.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"AbstractToolParser.extract_tool_calls has not been implemented!")
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Instance method that should be implemented for extracting tool calls
|
||||
from an incomplete response; for use when handling tool calls and
|
||||
streaming. Has to be an instance method because it requires state -
|
||||
the current tokens/diffs, but also the information about what has
|
||||
previously been parsed and extracted (see constructor)
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"AbstractToolParser.extract_tool_calls_streaming has not been "
|
||||
"implemented!")
|
344
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Normal file
@ -0,0 +1,344 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
InitialDeltaToolCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
logger.error(
|
||||
"Detected Mistral tokenizer when using a Hermes model")
|
||||
self.model_tokenizer = self.model_tokenizer.tokenizer
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent = False
|
||||
self.current_tool_initial_sent: bool = False
|
||||
self.streamed_args_for_tool: List[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
|
||||
self.scratch_pad_regex = re.compile(
|
||||
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_call_start_token_id: int = self.model_tokenizer.vocab[
|
||||
self.tool_call_start_token]
|
||||
self.tool_call_end_token_id: int = self.model_tokenizer.vocab[
|
||||
self.tool_call_end_token]
|
||||
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
|
||||
raise RuntimeError(
|
||||
"Hermes 2 Pro Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def extract_tool_calls(self,
|
||||
model_output: str) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_call_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
|
||||
try:
|
||||
# there are two possible captures - between tags, or between a
|
||||
# tag and end-of-string so the result of
|
||||
# findall is an array of tuples where one is a function call and
|
||||
# the other is None
|
||||
function_call_tuples = (
|
||||
self.tool_call_regex.findall(model_output))
|
||||
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
raw_function_calls = [
|
||||
json.loads(match[0] if match[0] else match[1])
|
||||
for match in function_call_tuples
|
||||
]
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"])))
|
||||
for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_call_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in extracting tool call from response %s",
|
||||
e)
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
# tool call start token in the tokens generated so far?
|
||||
if self.tool_call_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
if delta_text != self.tool_call_end_token:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# case: if tool open & close tag counts don't match, we're doing
|
||||
# imaginary "else" block here
|
||||
# something with tools with this diff.
|
||||
# flags for partial JSON parting. exported constants from
|
||||
# "Allow" are handled via BIT MASK
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.current_tool_initial_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count > prev_tool_end_count):
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = json.dumps(diff).replace(
|
||||
self.streamed_args_for_tool[self.current_tool_id], "")
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s", diff)
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
try:
|
||||
|
||||
current_tool_call = partial_json_parser.loads(
|
||||
tool_call_portion or "{}",
|
||||
flags) if tool_call_portion else None
|
||||
logger.debug("Parsed tool call %s", current_tool_call)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# case - we haven't sent the initial delta with the tool call ID
|
||||
# (it will be sent)
|
||||
if not self.current_tool_initial_sent:
|
||||
self.current_tool_initial_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
InitialDeltaToolCall(
|
||||
index=self.current_tool_id).model_dump(
|
||||
exclude_none=True)
|
||||
])
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
elif not self.current_tool_name_sent:
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
else:
|
||||
return None
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = DeltaMessage(content=delta_text) \
|
||||
if text_portion is not None else None
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = (
|
||||
self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but none are now.
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
|
||||
cur_arguments_json = json.dumps(cur_arguments)
|
||||
logger.debug("finding %s in %s", delta_text,
|
||||
cur_arguments_json)
|
||||
|
||||
# get the location where previous args differ from current
|
||||
args_delta_start_loc = cur_arguments_json.index(delta_text) \
|
||||
+ len(delta_text)
|
||||
|
||||
# use that to find the actual delta
|
||||
arguments_delta = cur_arguments_json[:args_delta_start_loc]
|
||||
logger.debug("First tokens in arguments received: %s",
|
||||
arguments_delta)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= arguments_delta
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
prev_args_json = json.dumps(prev_arguments)
|
||||
logger.debug("Searching for diff between\n%s", cur_args_json)
|
||||
logger.debug("and\n%s", prev_args_json)
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
logger.debug("got argument diff %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= argument_diff
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[self.current_tool_id] = \
|
||||
current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error trying to handle streaming tool call: %s", e)
|
||||
return None # do not stream a delta. skip this token ID.
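For reference, a hedged example of the Hermes 2 Pro output shape this parser targets and how the non-streaming path interprets it; the tool name and arguments are made up, and the regex below simply mirrors the one defined in the constructor above.

```python
import json
import re

# Hypothetical Hermes-style model output: optional plain content followed by
# one or more <tool_call>...</tool_call> blocks containing a JSON object.
model_output = (
    "Let me check that for you.\n"
    "<tool_call>\n"
    '{"name": "get_current_weather", "arguments": {"city": "Dallas"}}\n'
    "</tool_call>"
)

# Same two-capture regex as the parser: between tags, or tag to end-of-string.
tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)",
                             re.DOTALL)
matches = tool_call_regex.findall(model_output)
raw_calls = [json.loads(m[0] if m[0] else m[1]) for m in matches]

assert raw_calls[0]["name"] == "get_current_weather"
# extract_tool_calls() would report tools_called=True and content equal to the
# text before the first <tool_call> tag ("Let me check that for you.").
```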
|
293
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Normal file
@ -0,0 +1,293 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
InitialDeltaToolCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MistralToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
|
||||
examples/tool_chat_template_mistral.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice and --tool-call-parser mistral are both set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
self.model_tokenizer = self.model_tokenizer.tokenizer
|
||||
else:
|
||||
logger.info("Non-Mistral tokenizer detected when using a Mistral "
|
||||
"model...")
|
||||
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.current_tool_initial_sent: bool = False
|
||||
self.streamed_args_for_tool: List[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = "[TOOL_CALLS]"
|
||||
self.bot_token_id = self.model_tokenizer.vocab[self.bot_token]
|
||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
||||
|
||||
def extract_tool_calls(self,
|
||||
model_output: str) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing;
make sure your tool call arguments never include single quotes!
|
||||
"""
|
||||
|
||||
# case -- if a tool call token is not present, return a text response
|
||||
if self.bot_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
try:
|
||||
|
||||
# use a regex to find the tool call. remove the BOT token
|
||||
# and make sure to replace single quotes with double quotes
|
||||
raw_tool_call = self.tool_call_regex.findall(
|
||||
model_output.replace(self.bot_token, ""))[0]
|
||||
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
tool_calls: List[ToolCall] = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=raw_function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(raw_function_call["arguments"])))
|
||||
for raw_function_call in function_call_arr
|
||||
]
|
||||
|
||||
# get any content before the tool call
|
||||
content = model_output.split(self.bot_token)[0]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if len(content) > 0 else None)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in extracting tool call from response: %s", e)
|
||||
print("ERROR", e)
|
||||
# return information to just treat the tool call as regular JSON
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
# if the tool call token is not in the tokens generated so far, append
|
||||
# output to contents since it's not a tool
|
||||
if self.bot_token_id not in current_token_ids:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# if the tool call token ID IS in the tokens generated so far, that
|
||||
# means we're parsing as tool calls now
|
||||
|
||||
# handle if we detected the BOT token which means the start of tool
|
||||
# calling
|
||||
if (self.bot_token_id in delta_token_ids
|
||||
and len(delta_token_ids) == 1):
|
||||
# if it's the only token, return None, so we don't send a chat
# completion chunk and don't send a control token
|
||||
return None
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
|
||||
# replace BOT token with empty string, and convert single quotes
|
||||
# to double to allow parsing as JSON since mistral uses single
|
||||
# quotes instead of double for tool calls
|
||||
parsable_arr = current_text.split(self.bot_token)[1]
|
||||
|
||||
# tool calls are generated in an array, so do partial JSON
|
||||
# parsing on the entire array
|
||||
try:
|
||||
tool_call_arr: List[Dict] = partial_json_parser.loads(
|
||||
parsable_arr, flags)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one that our state says we're on
|
||||
|
||||
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
diff: Union[str, None] = current_tool_call.get("arguments")
|
||||
|
||||
if diff:
|
||||
diff = json.dumps(diff).replace(
|
||||
self.streamed_args_for_tool[self.current_tool_id],
|
||||
"")
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.current_tool_initial_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
    # case: update an existing tool - this is handled below

    # if the current tool initial data incl. the id, type=function
    # and idx not sent, send that
    if not self.current_tool_initial_sent:
        self.current_tool_initial_sent = True
        delta = DeltaMessage(tool_calls=[
            InitialDeltaToolCall(
                index=self.current_tool_id).model_dump(
                    exclude_none=True)
        ])

    # if the current tool name hasn't been sent, send if available
    # - otherwise send nothing
    elif not self.current_tool_name_sent:
        function_name = current_tool_call.get("name")
        if function_name:
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              function=DeltaFunctionCall(
                                  name=function_name).model_dump(
                                      exclude_none=True))
            ])
            self.current_tool_name_sent = True
        else:
            delta = None

    # now we know we're on the same tool call and we're streaming
    # arguments
    else:
        prev_arguments = self.prev_tool_call_arr[
            self.current_tool_id].get("arguments")
        cur_arguments = current_tool_call.get("arguments")

        new_text = delta_text.replace("\'", "\"")

        if not cur_arguments and not prev_arguments:
            delta = None
        elif not cur_arguments and prev_arguments:
            logger.error(
                "INVARIANT - impossible to have arguments reset "
                "mid-arguments")
            delta = None
        elif cur_arguments and not prev_arguments:
            cur_arguments_json = json.dumps(cur_arguments)
            logger.debug("finding %s in %s", new_text,
                         cur_arguments_json)

            arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                 index(new_text) +
                                                 len(new_text)]
            logger.debug("First tokens in arguments received: %s",
                         arguments_delta)
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              function=DeltaFunctionCall(
                                  arguments=arguments_delta).
                              model_dump(exclude_none=True))
            ])
            self.streamed_args_for_tool[
                self.current_tool_id] += arguments_delta

        elif cur_arguments and prev_arguments:
            cur_args_json = json.dumps(cur_arguments)
            prev_args_json = json.dumps(prev_arguments)
            logger.debug("Searching for diff between \n%s\n%s",
                         cur_args_json, prev_args_json)

            argument_diff = extract_intermediate_diff(
                cur_args_json, prev_args_json)
            logger.debug("got arguments diff: %s", argument_diff)
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              function=DeltaFunctionCall(
                                  arguments=argument_diff).model_dump(
                                      exclude_none=True))
            ])
            self.streamed_args_for_tool[
                self.current_tool_id] += argument_diff
        else:
            # try parsing it with regular JSON - if it works we're
            # at the end, and we need to send the difference between
            # tokens streamed so far and the valid JSON
            delta = None

    # check to see if the name is defined and has been sent. if so,
    # stream the name - otherwise keep waiting
    # finish by setting old and returning None as base case
    self.prev_tool_call_arr = tool_call_arr
    return delta

except Exception as e:
    logger.error("Error trying to handle streaming tool call: %s", e)
    logger.debug(
        "Skipping chunk as a result of tool streaming extraction "
        "error")
    return None
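For reference, here is a minimal sketch of how the `Allow` bit mask above changes what `partial_json_parser` yields for an in-progress tool-call array. The input string and the printed results are illustrative assumptions, not taken from this diff; the `loads` call and `Allow` flags mirror the ones used by the parser.

```python
# Sketch: effect of the Allow bit mask on partial JSON parsing of an
# in-progress tool-call array (input and outputs are illustrative).
import partial_json_parser
from partial_json_parser.core.options import Allow

partial = '[{"name": "get_weather", "arguments": {"city": "Da'

# Allow.ALL: incomplete strings are closed and returned, so a client could
# receive a half-generated argument value.
print(partial_json_parser.loads(partial, Allow.ALL))
# roughly: [{'name': 'get_weather', 'arguments': {'city': 'Da'}}]

# Allow.ALL & ~Allow.STR: incomplete strings are withheld, which is why the
# parser only streams the function name once it is fully generated.
print(partial_json_parser.loads(partial, Allow.ALL & ~Allow.STR))
# roughly: [{'name': 'get_weather', 'arguments': {}}]
```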
vllm/entrypoints/openai/tool_parsers/utils.py (new file, 87 lines)
@ -0,0 +1,87 @@
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Finds a common prefix that is shared between two strings, if there is one.
    Order of arguments is NOT important.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'
    """
    prefix = ''
    min_length = min(len(s1), len(s2))
    for i in range(0, min_length):
        if s1[i] == s2[i]:
            prefix += s1[i]
        else:
            break
    return prefix


def find_common_suffix(s1: str, s2: str) -> str:
    """
    Finds a common suffix shared between two strings, if there is one. Order of
    arguments is NOT important.
    Stops when the suffix ends OR it hits an alphanumeric character

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    suffix = ''
    min_length = min(len(s1), len(s2))
    for i in range(1, min_length + 1):
        if s1[-i] == s2[-i] and not s1[-i].isalnum():
            suffix = s1[-i] + suffix
        else:
            break
    return suffix


def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Given two strings, extract the difference in the middle between two strings
    that are known to have a common prefix and/or suffix.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely. The order of arguments IS
    important - the new version of the partially-parsed JSON must be the first
    argument, and the second argument must be from the previous generation.

    What it returns is the tokens that should be streamed to the client.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
    -> 'ple'

    """
    suffix = find_common_suffix(curr, old)

    old = old[::-1].replace(suffix[::-1], '', 1)[::-1]
    prefix = find_common_prefix(curr, old)
    diff = curr
    if len(suffix):
        diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1]

    if len(prefix):
        # replace the prefix only once in case it's mirrored
        diff = diff.replace(prefix, '', 1)

    return diff


def find_all_indices(string, substring):
    """
    Find all (starting) indices of a substring in a given string. Useful for
    tool call extraction
    """
    indices = []
    index = -1
    while True:
        index = string.find(substring, index + 1)
        if index == -1:
            break
        indices.append(index)
    return indices
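As a quick illustration of how the streaming-arguments branch of the Mistral parser uses these helpers, the snippet below reproduces one step standalone. The argument values are made up for the example, and it assumes the new utils module above is importable.

```python
# Sketch: two successive snapshots of partially-parsed "arguments", as the
# streaming branch sees them, and the delta that gets streamed to the client.
import json

from vllm.entrypoints.openai.tool_parsers.utils import (
    extract_intermediate_diff)

prev_arguments = {"city": "Dallas"}                  # previous partial parse
cur_arguments = {"city": "Dallas", "state": "TX"}    # current partial parse

prev_args_json = json.dumps(prev_arguments)  # '{"city": "Dallas"}'
cur_args_json = json.dumps(cur_arguments)    # '{"city": "Dallas", "state": "TX"}'

# The shared prefix '{"city": "Dallas' and the shared '"}' suffix are
# stripped, so close-quotes and close-braces are not streamed prematurely.
print(extract_intermediate_diff(cur_args_json, prev_args_json))
# -> '", "state": "TX'
```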
@ -59,8 +59,9 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest,
     if type(request) is CompletionRequest:
         return request

-    # user has chosen to not use any tool
-    if request.tool_choice == "none":
+    # user has chosen to not use any tool,
+    # OR is allowing the model to choose a tool.
+    if request.tool_choice == "none" or request.tool_choice == "auto":
         return request

     # user has chosen to use a named tool
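The practical effect of letting `tool_choice="auto"` pass through unchanged is that no JSON schema is forced onto generation; the tool parser decides afterwards whether a tool call was produced. A client-side sketch follows; the model name, server address and tool schema are assumptions for illustration, not part of this diff.

```python
# Sketch of a request exercising the "auto" branch above; assumes a vLLM
# server with automatic tool choice enabled and a tool parser selected.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="-")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

# tool_choice="auto" is now returned unchanged by _adapt_request_for_tool_use,
# so the model may or may not decide to call the tool.
response = client.chat.completions.create(
    model="NousResearch/Hermes-2-Pro-Llama-3-8B",  # assumed example model
    messages=[{"role": "user", "content": "What's the weather in Dallas?"}],
    tools=tools,
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)
```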
@ -8,8 +8,9 @@ from typing import Tuple, Union
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              CompletionRequest)
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
+    CompletionRequest)
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
@ -101,16 +102,30 @@ def _get_guide_and_mode(
     request: Union[CompletionRequest, ChatCompletionRequest,
                    GuidedDecodingRequest]
 ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
+    # if the request is a chat completion request, AND the tool choice is a
+    # named tool choice, do guided decoding
+    # using that tool as the JSON schema
+    if isinstance(request, ChatCompletionRequest) and isinstance(
+            request.tool_choice, ChatCompletionNamedToolChoiceParam):
+        # Guided generation for tools/functions parameters
+        if request.tool_choice.type == "function":
+            for tool in request.tools:
+                if (tool.type == "function" and tool.function.name
+                        == request.tool_choice.function.name):
+                    json = json_dumps(tool.function.parameters, sort_keys=True)
+                    return json, GuidedDecodingMode.JSON
+        return None, None

-    if request.guided_json:
-        json = request.guided_json
-        if isinstance(json, dict):
+    elif request.guided_json:
+        if isinstance(request.guided_json, dict):
             # turn dict into hashable string
-            json = json_dumps(json)
-        elif isinstance(json, BaseModel):
+            json = json_dumps(request.guided_json)
+        elif isinstance(request.guided_json, BaseModel):
             # use pydantic signature so that different model classes
             # with the same fields will get hashed the same
-            json = str(json.__signature__)
+            json = str(request.guided_json.__signature__)
+        else:
+            json = request.guided_json
         return json, GuidedDecodingMode.JSON
     elif request.guided_regex:
         return request.guided_regex, GuidedDecodingMode.REGEX
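To make the new named-tool-choice branch concrete, here is a small sketch of the same lookup using plain dicts instead of the pydantic request objects; the tool definition and request shape are illustrative assumptions.

```python
# Sketch: what the named-tool-choice branch selects as the guide - the
# parameters schema of the requested tool, serialized deterministically.
import json

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
tool_choice = {"type": "function",
               "function": {"name": "get_current_weather"}}

guide = None
for tool in tools:
    if (tool["type"] == "function"
            and tool["function"]["name"] == tool_choice["function"]["name"]):
        # sort_keys=True so equivalent schemas serialize (and hash) the same
        guide = json.dumps(tool["function"]["parameters"], sort_keys=True)

print(guide)  # the JSON schema that guided decoding will enforce
```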