[Frontend] Support tool calling and reasoning parser (#14511)
Signed-off-by: WangErXiao <863579016@qq.com>
Parent: bc8ed3c4ba
Commit: d6cd59f122
@@ -118,7 +118,7 @@ steps:
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
   - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
-  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/
+  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/
   - pytest -v -s entrypoints/test_chat_utils.py
   - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
@@ -10,10 +10,10 @@ Reasoning models return an additional `reasoning_content` field in their outputs,

 vLLM currently supports the following reasoning models:

-| Model Series | Parser Name | Structured Output Support |
-|--------------|-------------|---------------------------|
-| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` |
-| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` |
+| Model Series | Parser Name | Structured Output Support | Tool Calling |
+|--------------|-------------|---------------------------|--------------|
+| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ |
+| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ |

 ## Quickstart
@@ -170,10 +170,51 @@ print("reasoning_content: ", completion.choices[0].message.reasoning_content)
 print("content: ", completion.choices[0].message.content)
 ```

+## Tool Calling
+
+The reasoning content is also available when both tool calling and the reasoning parser are enabled. Additionally, tool calling only parses functions from the `content` field, not from the `reasoning_content`.
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
+
+tools = [{
+    "type": "function",
+    "function": {
+        "name": "get_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "location": {"type": "string", "description": "City and state, e.g., 'San Francisco, CA'"},
+                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
+            },
+            "required": ["location", "unit"]
+        }
+    }
+}]
+
+response = client.chat.completions.create(
+    model=client.models.list().data[0].id,
+    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
+    tools=tools,
+    tool_choice="auto"
+)
+
+print(response)
+tool_call = response.choices[0].message.tool_calls[0].function
+
+print(f"reasoning_content: {response.choices[0].message.reasoning_content}")
+print(f"Function called: {tool_call.name}")
+print(f"Arguments: {tool_call.arguments}")
+```
+
+For more examples, please refer to <gh-file:examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py>.
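Streaming works the same way. A minimal sketch, reusing the `client` and `tools` from the example above and following the accumulation pattern of the example script referenced there (the field names match what that script reads from each streamed delta):

```python
stream = client.chat.completions.create(
    model=client.models.list().data[0].id,
    messages=[{"role": "user",
               "content": "What's the weather like in San Francisco?"}],
    tools=tools,
    tool_choice="auto",
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        call = delta.tool_calls[0]
        # Accumulate call.function.name / call.function.arguments,
        # keyed by call.index, as the example script does.
        pass
    elif hasattr(delta, "reasoning_content") and delta.reasoning_content:
        print(delta.reasoning_content, end="")
```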
 ## Limitations

 - The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
-- It is not compatible with [`tool_calling`](#tool_calling).

 ## How to support a new reasoning model
examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py (new file, 177 lines)
@@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
"""
An example that demonstrates how to use tool calling with reasoning models
like QwQ-32B. The reasoning_content will not be parsed by the tool
calling process; only the final output will be parsed.

To run this example, you need to start the vLLM server with both
the reasoning parser and tool calling enabled.

```bash
vllm serve Qwen/QwQ-32B \
    --enable-reasoning --reasoning-parser deepseek_r1 \
    --enable-auto-tool-choice --tool-call-parser hermes
```
"""

from openai import OpenAI


# Now, simulate a tool call
def get_current_weather(city: str, state: str, unit: str):
    return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
            "partly cloudy, with highs in the 90's.")


available_tools = {"get_current_weather": get_current_weather}

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description":
                    "The city to find the weather for, e.g. 'San Francisco'"
                },
                "state": {
                    "type": "string",
                    "description":
                    "the two-letter abbreviation for the state that the city"
                    " is in, e.g. 'CA' which would mean 'California'"
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["city", "state", "unit"]
        }
    }
}]

messages = [{
    "role": "user",
    "content": "Hi! How are you doing today?"
}, {
    "role": "assistant",
    "content": "I'm doing well! How can I help you?"
}, {
    "role": "user",
    "content":
    "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
}]


def extract_reasoning_and_calls(chunks: list):
    # Accumulate reasoning text and per-call function names/arguments
    # from a list of streamed chunks.
    reasoning_content = ""
    tool_call_idx = -1
    arguments = []
    function_names = []
    for chunk in chunks:
        if chunk.choices[0].delta.tool_calls:
            tool_call = chunk.choices[0].delta.tool_calls[0]
            # A new index signals the start of a new tool call.
            if tool_call.index != tool_call_idx:
                tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
                arguments.append("")
                function_names.append("")

            if tool_call.function:
                if tool_call.function.name:
                    function_names[tool_call_idx] = tool_call.function.name

                if tool_call.function.arguments:
                    arguments[tool_call_idx] += tool_call.function.arguments
        else:
            if hasattr(chunk.choices[0].delta, "reasoning_content"):
                reasoning_content += chunk.choices[0].delta.reasoning_content
    return reasoning_content, arguments, function_names


print("---------Full Generate With Automatic Function Calling-------------")
tool_calls = client.chat.completions.create(messages=messages,
                                            model=model,
                                            tools=tools)
print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
print(f"function name: "
      f"{tool_calls.choices[0].message.tool_calls[0].function.name}")
print(f"function arguments: "
      f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}")

print("----------Stream Generate With Automatic Function Calling-----------")
tool_calls_stream = client.chat.completions.create(messages=messages,
                                                   model=model,
                                                   tools=tools,
                                                   stream=True)
chunks = []
for chunk in tool_calls_stream:
    chunks.append(chunk)

reasoning_content, arguments, function_names = extract_reasoning_and_calls(
    chunks)

print(f"reasoning_content: {reasoning_content}")
print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}")

print("----------Full Generate With Named Function Calling-----------------")
tool_calls = client.chat.completions.create(messages=messages,
                                            model=model,
                                            tools=tools,
                                            tool_choice={
                                                "type": "function",
                                                "function": {
                                                    "name":
                                                    "get_current_weather"
                                                }
                                            })

tool_call = tool_calls.choices[0].message.tool_calls[0].function
print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
print(f"function name: {tool_call.name}")
print(f"function arguments: {tool_call.arguments}")

print("----------Stream Generate With Named Function Calling--------------")
tool_calls_stream = client.chat.completions.create(
    messages=messages,
    model=model,
    tools=tools,
    tool_choice={
        "type": "function",
        "function": {
            "name": "get_current_weather"
        }
    },
    stream=True)

chunks = []
for chunk in tool_calls_stream:
    chunks.append(chunk)

reasoning_content, arguments, function_names = extract_reasoning_and_calls(
    chunks)
print(f"reasoning_content: {reasoning_content}")
print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}")
print("\n\n")
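With the server from the docstring running, the script can be invoked directly:

```bash
python examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py
```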
tests/entrypoints/openai/test_chat_with_tool_reasoning.py (new file, 145 lines)
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from ...utils import RemoteOpenAIServer

# a reasoning and tool calling model
MODEL_NAME = "Qwen/QwQ-32B"


@pytest.fixture(scope="module")
def server():  # noqa: F811
    args = [
        "--max-model-len", "8192", "--enforce-eager", "--enable-reasoning",
        "--reasoning-parser", "deepseek_r1", "--enable-auto-tool-choice",
        "--tool-call-parser", "hermes"
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


TOOLS = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description":
                    "The city to find the weather for, e.g. 'San Francisco'"
                },
                "state": {
                    "type": "string",
                    "description":
                    "the two-letter abbreviation for the state that the city"
                    " is in, e.g. 'CA' which would mean 'California'"
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["city", "state", "unit"]
        }
    }
}]

MESSAGES = [{
    "role": "user",
    "content": "Hi! How are you doing today?"
}, {
    "role": "assistant",
    "content": "I'm doing well! How can I help you?"
}, {
    "role": "user",
    "content":
    "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
}]

FUNC_NAME = "get_current_weather"
FUNC_ARGS = """{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}"""


def extract_reasoning_and_calls(chunks: list):
    reasoning_content = ""
    tool_call_idx = -1
    arguments = []
    function_names = []
    for chunk in chunks:
        if chunk.choices[0].delta.tool_calls:
            tool_call = chunk.choices[0].delta.tool_calls[0]
            if tool_call.index != tool_call_idx:
                tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
                arguments.append("")
                function_names.append("")

            if tool_call.function:
                if tool_call.function.name:
                    function_names[tool_call_idx] = tool_call.function.name

                if tool_call.function.arguments:
                    arguments[tool_call_idx] += tool_call.function.arguments
        else:
            if hasattr(chunk.choices[0].delta, "reasoning_content"):
                reasoning_content += chunk.choices[0].delta.reasoning_content
    return reasoning_content, arguments, function_names


# test streaming
@pytest.mark.asyncio
async def test_chat_streaming_of_tool_and_reasoning(
        client: openai.AsyncOpenAI):
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES,
        tools=TOOLS,
        temperature=0.0,
        stream=True,
    )

    chunks = []
    async for chunk in stream:
        chunks.append(chunk)

    reasoning_content, arguments, function_names = extract_reasoning_and_calls(
        chunks)
    assert len(reasoning_content) > 0
    assert len(function_names) > 0 and function_names[0] == FUNC_NAME
    assert len(arguments) > 0 and arguments[0] == FUNC_ARGS


# test full generate
@pytest.mark.asyncio
async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI):
    tool_calls = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES,
        tools=TOOLS,
        temperature=0.0,
        stream=False,
    )

    assert len(tool_calls.choices[0].message.reasoning_content) > 0
    assert tool_calls.choices[0].message.tool_calls[0].function.name \
        == FUNC_NAME
    assert tool_calls.choices[0].message.tool_calls[0].function.arguments \
        == FUNC_ARGS
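This test is ignored in the grouped `entrypoints/openai` CI step above, so it presumably runs in its own step; locally it can be invoked with the same relative path the CI uses:

```bash
pytest -v -s entrypoints/openai/test_chat_with_tool_reasoning.py
```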
@@ -289,13 +289,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
         raise TypeError("Error: --enable-reasoning requires "
                         "--reasoning-parser")

-    # Ref https://api-docs.deepseek.com/guides/reasoning_model
-    # tool call and reasoning cannot be enabled at the same time.
-    if args.enable_auto_tool_choice and args.enable_reasoning:
-        raise TypeError(
-            "Error: --enable-auto-tool-choice and "
-            "--enable-reasoning cannot be enabled at the same time")
-

 def create_parser_for_docs() -> FlexibleArgumentParser:
     parser_for_docs = FlexibleArgumentParser(
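With the mutual-exclusion check gone, one server can enable both features at once; the example script's docstring uses exactly this combination:

```bash
vllm serve Qwen/QwQ-32B \
    --enable-reasoning --reasoning-parser deepseek_r1 \
    --enable-auto-tool-choice --tool-call-parser hermes
```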
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
+from abc import abstractmethod
 from collections.abc import Sequence
 from functools import cached_property
 from typing import Callable, Optional, Union
@@ -76,6 +77,40 @@ class ReasoningParser:
             "AbstractReasoningParser.extract_reasoning_content_streaming "
             "has not been implemented!")

+    # TODO: need to rebase by PR #14428
+    @abstractmethod
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        """
+        Check if the reasoning content ends in the input_ids.
+
+        Parameters:
+        input_ids: list[int]
+            The input_ids of the model output.
+
+        Returns:
+        bool
+            True if the reasoning content ends in the input_ids.
+        """
+
+        raise NotImplementedError(
+            "AbstractReasoningParser.is_reasoning_end has "
+            "not been implemented!")
+
+    # TODO: need to rebase by PR #14428
+    @abstractmethod
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        """
+        Extract content token ids from the input_ids.
+
+        Parameters:
+        input_ids: list[int]
+            The input_ids of the model output.
+
+        Returns:
+        list[int]
+            The extracted content from the input_ids.
+        """
+
+        raise NotImplementedError(
+            "AbstractReasoningParser.extract_content_ids has"
+            " not been implemented!")
+
+
 class ReasoningParserManager:
     reasoning_parsers: dict[str, type] = {}
@@ -45,6 +45,19 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
                 "DeepSeek R1 reasoning parser could not locate think start/end "
                 "tokens in the tokenizer!")

+    # TODO: need to rebase by PR #14428
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.think_end_token_id in input_ids
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        """
+        Extract the content after the end tokens
+        """
+        if self.think_end_token_id not in input_ids[:-1]:
+            return []
+        else:
+            return input_ids[input_ids.index(self.think_end_token_id) + 1:]
+
     def extract_reasoning_content_streaming(
         self,
         previous_text: str,
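A toy illustration of the two new hooks; the token id is a made-up stand-in (the real `think_end_token_id` comes from the tokenizer):

```python
THINK_END_ID = 7  # hypothetical id for the </think> token


def is_reasoning_end(input_ids: list[int]) -> bool:
    return THINK_END_ID in input_ids


def extract_content_ids(input_ids: list[int]) -> list[int]:
    # Mirrors the DeepSeek R1 logic above: the [:-1] slice means a
    # </think> arriving as the very last token yields no content yet.
    if THINK_END_ID not in input_ids[:-1]:
        return []
    return input_ids[input_ids.index(THINK_END_ID) + 1:]


print(is_reasoning_end([1, 2, 7, 9]))     # True
print(extract_content_ids([1, 2, 7, 9]))  # [9]
print(extract_content_ids([1, 2, 9, 7]))  # [] (end token is last)
```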
@@ -328,6 +328,9 @@ class OpenAIServingChat(OpenAIServing):
             # These are only required in "auto" tool choice case
             previous_texts = [""] * num_choices
             all_previous_token_ids = [[]] * num_choices
+            # For when both the reasoning parser and tool calling are enabled
+            added_content_delta_arr = [False] * num_choices
+            reasoning_end_arr = [False] * num_choices
         else:
             previous_texts, all_previous_token_ids = None, None
@@ -477,27 +480,116 @@ class OpenAIServingChat(OpenAIServing):

             delta_message: Optional[DeltaMessage]

-            # handle streaming deltas for tools with named tool_choice
-            if tool_choice_function_name:
-                delta_message = DeltaMessage(tool_calls=[
-                    DeltaToolCall(function=DeltaFunctionCall(
-                        name=tool_choice_function_name,
-                        arguments=delta_text),
-                                  index=i)
-                ])
-
-            # handle streaming deltas for tools with "auto" tool choice
-            elif tool_choice_auto:
+            # just update previous_texts and previous_token_ids
+            if tool_choice_auto or should_stream_with_reasoning_parsing:
                 assert previous_texts is not None
                 assert all_previous_token_ids is not None
-                assert tool_parser is not None
-                #TODO optimize manipulation of these lists
                 previous_text = previous_texts[i]
                 previous_token_ids = all_previous_token_ids[i]
                 current_text = previous_text + delta_text
                 current_token_ids = previous_token_ids + list(
                     output.token_ids)

+            # handle streaming deltas for tools with named tool_choice
+            if tool_choice_function_name:
+                if (self.enable_reasoning
+                        and not reasoning_parser.is_reasoning_end(
+                            previous_token_ids)):
+                    assert reasoning_parser is not None
+                    delta_message = (
+                        reasoning_parser.
+                        extract_reasoning_content_streaming(
+                            previous_text,
+                            current_text,
+                            delta_text,
+                            previous_token_ids,
+                            current_token_ids,
+                            output.token_ids,
+                        ))
+                    # When encountering think end id in delta_token_ids,
+                    # process the `content`. Only keep 'content',
+                    # remove 'reasoning_content'
+                    if reasoning_parser.is_reasoning_end(
+                            list(output.token_ids)):
+                        if delta_message and delta_message.content:
+                            # This needs to be added to the next `delta_text`
+                            current_text = delta_message.content
+                            delta_message.content = None
+                        else:
+                            current_text = ""
+                else:
+                    # Just to add remaining `content`
+                    if self.enable_reasoning:
+                        delta_text = previous_text + delta_text
+                        current_text = ""
+
+                    delta_message = DeltaMessage(tool_calls=[
+                        DeltaToolCall(function=DeltaFunctionCall(
+                            name=tool_choice_function_name,
+                            arguments=delta_text),
+                                      index=i)
+                    ])
+
+            # handle streaming deltas for tools with "auto" tool choice
+            # and reasoning parser
+            elif tool_choice_auto and self.enable_reasoning:
+                assert tool_parser is not None
+                assert reasoning_parser is not None
+                assert added_content_delta_arr is not None
+                assert reasoning_end_arr is not None
+                if not reasoning_end_arr[i]:
+                    delta_message = (
+                        reasoning_parser.
+                        extract_reasoning_content_streaming(
+                            previous_text,
+                            current_text,
+                            delta_text,
+                            previous_token_ids,
+                            current_token_ids,
+                            output.token_ids,
+                        ))
+
+                    # When encountering think end id in delta_token_ids,
+                    # set reasoning status to end.
+                    # Remove the text and token ids related
+                    # to 'reasoning_content'.
+                    if reasoning_parser.is_reasoning_end(
+                            list(output.token_ids)):
+                        reasoning_end_arr[i] = True
+                        current_token_ids = \
+                            reasoning_parser.extract_content_ids(
+                                list(output.token_ids))
+                        if delta_message and delta_message.content:
+                            current_text = delta_message.content
+                            delta_message.content = None
+                        else:
+                            current_text = ""
+
+                # handle tool calls only after reasoning is done
+                else:
+                    delta_token_ids = list(output.token_ids)
+                    # First time to tool call,
+                    # add the remaining text and token ids
+                    # to delta from previous
+                    if not added_content_delta_arr[i]:
+                        added_content_delta_arr[i] = True
+                        previous_text = ""
+                        previous_token_ids = []
+                        delta_text = current_text
+                        delta_token_ids = current_token_ids
+
+                    delta_message = (
+                        tool_parser.extract_tool_calls_streaming(
+                            previous_text=previous_text,
+                            current_text=current_text,
+                            delta_text=delta_text,
+                            previous_token_ids=previous_token_ids,
+                            current_token_ids=current_token_ids,
+                            delta_token_ids=delta_token_ids,
+                            request=request))
+
+            # when only tool calls
+            elif tool_choice_auto:
+                assert tool_parser is not None
                 delta_message = (
                     tool_parser.extract_tool_calls_streaming(
                         previous_text=previous_text,
@@ -507,23 +599,9 @@ class OpenAIServingChat(OpenAIServing):
                         current_token_ids=current_token_ids,
                         delta_token_ids=output.token_ids,
                         request=request))
-
-                # update the previous values for the next iteration
-                previous_texts[i] = current_text
-                all_previous_token_ids[i] = current_token_ids
-
-            # reasoning_content cannot be enabled with tool_choice.
-            # If it is, the tool_choice will be used instead.
+            # when only reasoning
             elif self.enable_reasoning:
-                # handle reasoning_content delta
                 assert reasoning_parser is not None
-                assert previous_texts is not None
-                assert all_previous_token_ids is not None
-                previous_text = previous_texts[i]
-                previous_token_ids = all_previous_token_ids[i]
-                current_text = previous_text + delta_text
-                current_token_ids = previous_token_ids + list(
-                    output.token_ids)
-
                 delta_message = (reasoning_parser.
                                  extract_reasoning_content_streaming(
                                      previous_text,
@@ -533,15 +611,17 @@ class OpenAIServingChat(OpenAIServing):
                                      current_token_ids,
                                      output.token_ids,
                                  ))
-
-                # update the previous values for the next iteration
-                previous_texts[i] = current_text
-                all_previous_token_ids[i] = current_token_ids
-
             # handle streaming just a content delta
             else:
                 delta_message = DeltaMessage(content=delta_text)

+            # update the previous values for the next iteration
+            if tool_choice_auto or should_stream_with_reasoning_parsing:
+                assert previous_texts is not None
+                assert all_previous_token_ids is not None
+                previous_texts[i] = current_text
+                all_previous_token_ids[i] = current_token_ids
+
             # set the previous values for the next iteration
             previous_num_tokens[i] += len(output.token_ids)
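Taken together, the streaming branches above implement a small per-choice state machine: deltas feed the reasoning parser until the think-end token is seen, and only text after that point reaches the tool parser. A self-contained sketch of that routing on plain strings (the `</think>` marker and names are illustrative; the real code works on token ids):

```python
THINK_END = "</think>"  # illustrative marker; vLLM checks token ids


def split_stream(deltas: list[str]) -> tuple[str, str]:
    """Route streamed deltas to reasoning text until the end marker,
    then to content (which would feed tool-call parsing)."""
    reasoning_parts, content_parts = [], []
    reasoning_ended = False  # mirrors reasoning_end_arr[i]
    for delta in deltas:
        if not reasoning_ended and THINK_END in delta:
            before, _, after = delta.partition(THINK_END)
            reasoning_parts.append(before)
            content_parts.append(after)  # remaining text becomes content
            reasoning_ended = True
        elif reasoning_ended:
            content_parts.append(delta)
        else:
            reasoning_parts.append(delta)
    return "".join(reasoning_parts), "".join(content_parts)


print(split_stream(
    ["I should call", " the weather tool.", "</think>",
     '{"name": "get_current_weather"}']))
# ('I should call the weather tool.', '{"name": "get_current_weather"}')
```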
@@ -739,24 +819,24 @@ class OpenAIServingChat(OpenAIServing):
             except RuntimeError as e:
                 logger.exception("Error in reasoning parser creation.")
                 return self.create_error_response(str(e))
+            # If the reasoning parser is enabled,
+            # tool calls are extracted exclusively from the content.
             reasoning_content, content = (
                 reasoning_parser.extract_reasoning_content(
                     output.text, request=request))
-
-            if reasoning_content:
-                message = ChatMessage(role=role,
-                                      content=content,
-                                      reasoning_content=reasoning_content)
-            else:
-                message = ChatMessage(role=role, content=output.text)
+        else:
+            reasoning_content = None
+            content = output.text

         # if auto tools are not enabled, and a named tool choice using
         # outlines is not being used
-        elif (not self.enable_auto_tools
-              or not self.tool_parser) and not isinstance(
-                  request.tool_choice, ChatCompletionNamedToolChoiceParam):
-            message = ChatMessage(role=role, content=output.text)
+        if (not self.enable_auto_tools
+                or not self.tool_parser) and not isinstance(
+                    request.tool_choice,
+                    ChatCompletionNamedToolChoiceParam):
+            message = ChatMessage(role=role,
+                                  reasoning_content=reasoning_content,
+                                  content=content)

         # if the request uses tools and specified a tool choice
         elif request.tool_choice and type(
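For the non-streaming path, the effect of extracting tool calls exclusively from the content can be sketched on plain strings (illustrative only; the real `extract_reasoning_content` is parser-specific):

```python
from typing import Optional


def extract_reasoning_content(text: str) -> tuple[Optional[str], str]:
    # Illustrative split at a literal </think> marker.
    if "</think>" not in text:
        return None, text
    reasoning, _, content = text.partition("</think>")
    return reasoning, content


reasoning, content = extract_reasoning_content(
    "The user wants Dallas weather.</think>Calling get_current_weather now.")
# Tool parsing then runs on `content` only, never on `reasoning`.
print(reasoning)  # The user wants Dallas weather.
print(content)    # Calling get_current_weather now.
```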
@@ -766,18 +846,21 @@ class OpenAIServingChat(OpenAIServing):
                 tokenizer, MistralTokenizer) else ToolCall
             message = ChatMessage(
                 role=role,
+                reasoning_content=reasoning_content,
                 content="",
                 tool_calls=[
                     tool_call_class(function=FunctionCall(
                         name=request.tool_choice.function.name,
-                        arguments=output.text))
+                        arguments=content))
                 ])

         # if the request doesn't use tool choice
         # OR specifies to not use a tool
         elif not request.tool_choice or request.tool_choice == "none":
-            message = ChatMessage(role=role, content=output.text)
+            message = ChatMessage(role=role,
+                                  reasoning_content=reasoning_content,
+                                  content=content)

         # handle when there are tools and tool choice is auto
         elif request.tools and (
@@ -792,20 +875,23 @@ class OpenAIServingChat(OpenAIServing):
                 return self.create_error_response(str(e))

             tool_call_info = tool_parser.extract_tool_calls(
-                output.text, request=request)
+                content if content is not None else "", request=request)
             # In the OpenAI API the finish_reason is "tools_called"
             # if the tool choice is auto and the model produced a tool
             # call. The same is not true for named function calls
             auto_tools_called = tool_call_info.tools_called
             if tool_call_info.tools_called:
                 message = ChatMessage(role=role,
+                                      reasoning_content=reasoning_content,
                                       content=tool_call_info.content,
                                       tool_calls=tool_call_info.tool_calls)

             else:
                 # FOR NOW make it a chat message; we will have to detect
                 # the type to make it later.
-                message = ChatMessage(role=role, content=output.text)
+                message = ChatMessage(role=role,
+                                      reasoning_content=reasoning_content,
+                                      content=content)

         # undetermined case that is still important to handle
         else:
@@ -813,7 +899,9 @@ class OpenAIServingChat(OpenAIServing):
                 "Error in chat_completion_full_generator - cannot determine"
                 " if tools should be extracted. Returning a standard chat "
                 "completion.")
-            message = ChatMessage(role=role, content=output.text)
+            message = ChatMessage(role=role,
+                                  reasoning_content=reasoning_content,
+                                  content=content)

         choice_data = ChatCompletionResponseChoice(
             index=output.index,