From d6cd59f122351a71e742e57221df2f26b3b2f9f9 Mon Sep 17 00:00:00 2001
From: Robin <863579016@qq.com>
Date: Mon, 24 Mar 2025 05:00:07 +0800
Subject: [PATCH] [Frontend] Support tool calling and reasoning parser (#14511)

Signed-off-by: WangErXiao <863579016@qq.com>
---
 .buildkite/test-pipeline.yaml                 |   2 +-
 docs/source/features/reasoning_outputs.md     |  51 ++++-
 ...at_completion_tool_calls_with_reasoning.py | 177 +++++++++++++++++
 .../openai/test_chat_with_tool_reasoning.py   | 145 ++++++++++++++
 vllm/entrypoints/openai/cli_args.py           |   7 -
 .../abs_reasoning_parsers.py                  |  35 ++++
 .../deepseek_r1_reasoning_parser.py           |  13 ++
 vllm/entrypoints/openai/serving_chat.py       | 188 +++++++++++++-----
 8 files changed, 555 insertions(+), 63 deletions(-)
 create mode 100644 examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py
 create mode 100644 tests/entrypoints/openai/test_chat_with_tool_reasoning.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 7e812cbc..217f869f 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -118,7 +118,7 @@ steps:
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
   - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
-  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/
+  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/
  - pytest -v -s entrypoints/test_chat_utils.py
  - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md
index b5fad263..0b170aad 100644
--- a/docs/source/features/reasoning_outputs.md
+++ b/docs/source/features/reasoning_outputs.md
@@ -10,10 +10,10 @@ Reasoning models return an additional `reasoning_content` field in their outputs,
 
 vLLM currently supports the following reasoning models:
 
-| Model Series | Parser Name | Structured Output Support |
-|--------------|-------------|------------------|
-| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` |
-| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` |
+| Model Series | Parser Name | Structured Output Support | Tool Calling |
+|--------------|-------------|------------------|-------------|
+| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ |
+| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ |
 
 ## Quickstart
 
@@ -170,10 +170,51 @@ print("reasoning_content: ", completion.choices[0].message.reasoning_content)
 print("content: ", completion.choices[0].message.content)
 ```
 
+## Tool Calling
+
+The reasoning content is also available when both tool calling and the reasoning parser are enabled. Note that tool calls are parsed only from the `content` field, never from the `reasoning_content`.
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
+
+tools = [{
+    "type": "function",
+    "function": {
+        "name": "get_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "location": {"type": "string", "description": "City and state, e.g., 'San Francisco, CA'"},
+                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
+            },
+            "required": ["location", "unit"]
+        }
+    }
+}]
+
+response = client.chat.completions.create(
+    model=client.models.list().data[0].id,
+    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
+    tools=tools,
+    tool_choice="auto"
+)
+
+print(response)
+tool_call = response.choices[0].message.tool_calls[0].function
+
+print(f"reasoning_content: {response.choices[0].message.reasoning_content}")
+print(f"Function called: {tool_call.name}")
+print(f"Arguments: {tool_call.arguments}")
+```
+
+For more examples, please refer to .
+
 ## Limitations
 
 - The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
-- It is not compatible with [`tool_calling`](#tool_calling).
 
 ## How to support a new reasoning model
diff --git a/examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py b/examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py
new file mode 100644
index 00000000..9e7a69c6
--- /dev/null
+++ b/examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py
@@ -0,0 +1,177 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+An example demonstrating how to use tool calling with reasoning models
+like QwQ-32B. The reasoning_content will not be parsed by the tool
+calling process; only the final output will be parsed.
+
+To run this example, you need to start the vLLM server with both
+the reasoning parser and tool calling enabled.
+
+```bash
+vllm serve Qwen/QwQ-32B \
+    --enable-reasoning --reasoning-parser deepseek_r1 \
+    --enable-auto-tool-choice --tool-call-parser hermes
+```
+"""
+
+from openai import OpenAI
+
+
+# Define the tool function that backs the simulated tool call.
+def get_current_weather(city: str, state: str, unit: str):
+    return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
+            "partly cloudy, with highs in the 90s.")
+
+
+available_tools = {"get_current_weather": get_current_weather}
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+tools = [{
+    "type": "function",
+    "function": {
+        "name": "get_current_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "city": {
+                    "type": "string",
+                    "description":
+                    "The city to find the weather for, e.g. 'San Francisco'"
+                },
+                "state": {
+                    "type": "string",
+                    "description":
+                    "The two-letter abbreviation for the state that the city"
+                    " is in, e.g. 'CA' which would mean 'California'"
+                },
+                "unit": {
+                    "type": "string",
+                    "description": "The unit to fetch the temperature in",
+                    "enum": ["celsius", "fahrenheit"]
+                }
+            },
+            "required": ["city", "state", "unit"]
+        }
+    }
+}]
+messages = [{
+    "role": "user",
+    "content": "Hi! How are you doing today?"
+}, {
+    "role": "assistant",
+    "content": "I'm doing well! How can I help you?"
+}, { + "role": + "user", + "content": + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" +}] + + +def extract_reasoning_and_calls(chunks: list): + reasoning_content = "" + tool_call_idx = -1 + arguments = [] + function_names = [] + for chunk in chunks: + if chunk.choices[0].delta.tool_calls: + tool_call = chunk.choices[0].delta.tool_calls[0] + if tool_call.index != tool_call_idx: + tool_call_idx = chunk.choices[0].delta.tool_calls[0].index + arguments.append("") + function_names.append("") + + if tool_call.function: + if tool_call.function.name: + function_names[tool_call_idx] = tool_call.function.name + + if tool_call.function.arguments: + arguments[tool_call_idx] += tool_call.function.arguments + else: + if hasattr(chunk.choices[0].delta, "reasoning_content"): + reasoning_content += chunk.choices[0].delta.reasoning_content + return reasoning_content, arguments, function_names + + +print("---------Full Generate With Automatic Function Calling-------------") +tool_calls = client.chat.completions.create(messages=messages, + model=model, + tools=tools) +print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}") +print(f"function name: " + f"{tool_calls.choices[0].message.tool_calls[0].function.name}") +print(f"function arguments: " + f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}") + +print("----------Stream Generate With Automatic Function Calling-----------") +tool_calls_stream = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=True) +chunks = [] +for chunk in tool_calls_stream: + chunks.append(chunk) + +reasoning_content, arguments, function_names = extract_reasoning_and_calls( + chunks) + +print(f"reasoning_content: {reasoning_content}") +print(f"function name: {function_names[0]}") +print(f"function arguments: {arguments[0]}") + +print("----------Full Generate With Named Function Calling-----------------") +tool_calls = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + tool_choice={ + "type": "function", + "function": { + "name": + "get_current_weather" + } + }) + +tool_call = tool_calls.choices[0].message.tool_calls[0].function +print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}") +print(f"function name: {tool_call.name}") +print(f"function arguments: {tool_call.arguments}") +print("----------Stream Generate With Named Function Calling--------------") + +tool_calls_stream = client.chat.completions.create( + messages=messages, + model=model, + tools=tools, + tool_choice={ + "type": "function", + "function": { + "name": "get_current_weather" + } + }, + stream=True) + +chunks = [] +for chunk in tool_calls_stream: + chunks.append(chunk) + +reasoning_content, arguments, function_names = extract_reasoning_and_calls( + chunks) +print(f"reasoning_content: {reasoning_content}") +print(f"function name: {function_names[0]}") +print(f"function arguments: {arguments[0]}") +print("\n\n") diff --git a/tests/entrypoints/openai/test_chat_with_tool_reasoning.py b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py new file mode 100644 index 00000000..53df1d92 --- /dev/null +++ b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + +# a reasoning and tool calling model +MODEL_NAME = "Qwen/QwQ-32B" + + +@pytest.fixture(scope="module") 
+def server():  # noqa: F811
+    args = [
+        "--max-model-len", "8192", "--enforce-eager", "--enable-reasoning",
+        "--reasoning-parser", "deepseek_r1", "--enable-auto-tool-choice",
+        "--tool-call-parser", "hermes"
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
+TOOLS = [{
+    "type": "function",
+    "function": {
+        "name": "get_current_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "city": {
+                    "type": "string",
+                    "description":
+                    "The city to find the weather for, e.g. 'San Francisco'"
+                },
+                "state": {
+                    "type": "string",
+                    "description":
+                    "The two-letter abbreviation for the state that the city"
+                    " is in, e.g. 'CA' which would mean 'California'"
+                },
+                "unit": {
+                    "type": "string",
+                    "description": "The unit to fetch the temperature in",
+                    "enum": ["celsius", "fahrenheit"]
+                }
+            },
+            "required": ["city", "state", "unit"]
+        }
+    }
+}]
+
+MESSAGES = [{
+    "role": "user",
+    "content": "Hi! How are you doing today?"
+}, {
+    "role": "assistant",
+    "content": "I'm doing well! How can I help you?"
+}, {
+    "role": "user",
+    "content":
+    "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
+}]
+
+FUNC_NAME = "get_current_weather"
+FUNC_ARGS = """{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}"""
+
+
+def extract_reasoning_and_calls(chunks: list):
+    reasoning_content = ""
+    tool_call_idx = -1
+    arguments = []
+    function_names = []
+    for chunk in chunks:
+        if chunk.choices[0].delta.tool_calls:
+            tool_call = chunk.choices[0].delta.tool_calls[0]
+            if tool_call.index != tool_call_idx:
+                tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
+                arguments.append("")
+                function_names.append("")
+
+            if tool_call.function:
+                if tool_call.function.name:
+                    function_names[tool_call_idx] = tool_call.function.name
+
+                if tool_call.function.arguments:
+                    arguments[tool_call_idx] += tool_call.function.arguments
+        else:
+            if hasattr(chunk.choices[0].delta, "reasoning_content"):
+                reasoning_content += chunk.choices[0].delta.reasoning_content
+    return reasoning_content, arguments, function_names
+
+
+# test streaming
+@pytest.mark.asyncio
+async def test_chat_streaming_of_tool_and_reasoning(
+        client: openai.AsyncOpenAI):
+
+    stream = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=MESSAGES,
+        tools=TOOLS,
+        temperature=0.0,
+        stream=True,
+    )
+
+    chunks = []
+    async for chunk in stream:
+        chunks.append(chunk)
+
+    reasoning_content, arguments, function_names = extract_reasoning_and_calls(
+        chunks)
+    assert len(reasoning_content) > 0
+    assert len(function_names) > 0 and function_names[0] == FUNC_NAME
+    assert len(arguments) > 0 and arguments[0] == FUNC_ARGS
+
+
+# test full generate
+@pytest.mark.asyncio
+async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI):
+
+    tool_calls = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=MESSAGES,
+        tools=TOOLS,
+        temperature=0.0,
+        stream=False,
+    )
+
+    assert len(tool_calls.choices[0].message.reasoning_content) > 0
+    assert tool_calls.choices[0].message.tool_calls[0].function.name \
+        == FUNC_NAME
+    assert tool_calls.choices[0].message.tool_calls[0].function.arguments \
+        == FUNC_ARGS
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 01c67b8a..e956920c 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -289,13 +289,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
         raise TypeError("Error: --enable-reasoning requires "
                         "--reasoning-parser")
 
-    # Ref https://api-docs.deepseek.com/guides/reasoning_model
-    # tool call and reasoning cannot be enabled at the same time.
-    if args.enable_auto_tool_choice and args.enable_reasoning:
-        raise TypeError(
-            "Error: --enable-auto-tool-choice and "
-            "--enable-reasoning cannot be enabled at the same time")
-
 
 def create_parser_for_docs() -> FlexibleArgumentParser:
     parser_for_docs = FlexibleArgumentParser(
diff --git a/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py b/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
index b3bc0e83..c95ff191 100644
--- a/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
+++ b/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from abc import abstractmethod
 from collections.abc import Sequence
 from functools import cached_property
 from typing import Callable, Optional, Union
@@ -76,6 +77,40 @@ class ReasoningParser:
             "AbstractReasoningParser.extract_reasoning_content_streaming "
             "has not been implemented!")
 
+    # TODO: rebase once PR #14428 lands
+    @abstractmethod
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        """
+        Check if the reasoning content ends in the input_ids.
+        Parameters:
+        input_ids: list[int]
+            The input_ids of the model output.
+        Returns:
+        bool
+            True if the reasoning content ends in the input_ids.
+        """
+
+        raise NotImplementedError(
+            "AbstractReasoningParser.is_reasoning_end has "
+            "not been implemented!")
+
+    # TODO: rebase once PR #14428 lands
+    @abstractmethod
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        """
+        Extract content token ids from the input_ids.
+        Parameters:
+        input_ids: list[int]
+            The input_ids of the model output.
+        Returns:
+        list[int]
+            The content token ids extracted from the input_ids.
+ """ + + raise NotImplementedError( + "AbstractReasoningParser.extract_content_ids has" + " not been implemented!") + class ReasoningParserManager: reasoning_parsers: dict[str, type] = {} diff --git a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py index 1a2c66a6..54e96016 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py @@ -45,6 +45,19 @@ class DeepSeekR1ReasoningParser(ReasoningParser): "DeepSeek R1 reasoning parser could not locate think start/end " "tokens in the tokenizer!") + # TODO: need to rebase by PR #14428 + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.think_end_token_id in input_ids + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + """ + Extract the content after the end tokens + """ + if self.think_end_token_id not in input_ids[:-1]: + return [] + else: + return input_ids[input_ids.index(self.think_end_token_id) + 1:] + def extract_reasoning_content_streaming( self, previous_text: str, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 130dfe18..3c35a848 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -328,6 +328,9 @@ class OpenAIServingChat(OpenAIServing): # These are only required in "auto" tool choice case previous_texts = [""] * num_choices all_previous_token_ids = [[]] * num_choices + # For reasoning parser and tool call all enabled + added_content_delta_arr = [False] * num_choices + reasoning_end_arr = [False] * num_choices else: previous_texts, all_previous_token_ids = None, None @@ -477,27 +480,116 @@ class OpenAIServingChat(OpenAIServing): delta_message: Optional[DeltaMessage] - # handle streaming deltas for tools with named tool_choice - if tool_choice_function_name: - delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(function=DeltaFunctionCall( - name=tool_choice_function_name, - arguments=delta_text), - index=i) - ]) - - # handle streaming deltas for tools with "auto" tool choice - elif tool_choice_auto: + # just update previous_texts and previous_token_ids + if tool_choice_auto or should_stream_with_reasoning_parsing: assert previous_texts is not None assert all_previous_token_ids is not None - assert tool_parser is not None - #TODO optimize manipulation of these lists previous_text = previous_texts[i] previous_token_ids = all_previous_token_ids[i] current_text = previous_text + delta_text current_token_ids = previous_token_ids + list( output.token_ids) + # handle streaming deltas for tools with named tool_choice + if tool_choice_function_name: + if (self.enable_reasoning + and not reasoning_parser.is_reasoning_end( + previous_token_ids)): + assert reasoning_parser is not None + delta_message = ( + reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) + # When encountering think end id in delta_token_ids, + # process the `content`. 
+                        # drop the 'reasoning_content'.
+                        if reasoning_parser.is_reasoning_end(
+                                list(output.token_ids)):
+                            if delta_message and delta_message.content:
+                                # This needs to be added to the next
+                                # `delta_text`
+                                current_text = delta_message.content
+                                delta_message.content = None
+                            else:
+                                current_text = ""
+                    else:
+                        # Just add the remaining `content`
+                        if self.enable_reasoning:
+                            delta_text = previous_text + delta_text
+                            current_text = ""
+
+                        delta_message = DeltaMessage(tool_calls=[
+                            DeltaToolCall(function=DeltaFunctionCall(
+                                name=tool_choice_function_name,
+                                arguments=delta_text),
+                                          index=i)
+                        ])
+
+                # handle streaming deltas for tools with "auto" tool choice
+                # and reasoning parser
+                elif tool_choice_auto and self.enable_reasoning:
+                    assert tool_parser is not None
+                    assert reasoning_parser is not None
+                    assert added_content_delta_arr is not None
+                    assert reasoning_end_arr is not None
+                    if not reasoning_end_arr[i]:
+                        delta_message = (
+                            reasoning_parser.
+                            extract_reasoning_content_streaming(
+                                previous_text,
+                                current_text,
+                                delta_text,
+                                previous_token_ids,
+                                current_token_ids,
+                                output.token_ids,
+                            ))
+
+                        # When the think end id appears in delta_token_ids,
+                        # mark reasoning as ended and drop the text and
+                        # token ids that belong to 'reasoning_content'.
+                        if reasoning_parser.is_reasoning_end(
+                                list(output.token_ids)):
+                            reasoning_end_arr[i] = True
+                            current_token_ids = \
+                                reasoning_parser.extract_content_ids(
+                                    list(output.token_ids))
+                            if delta_message and delta_message.content:
+                                current_text = delta_message.content
+                                delta_message.content = None
+                            else:
+                                current_text = ""
+
+                    # handle tool calls only after reasoning is done
+                    else:
+                        delta_token_ids = list(output.token_ids)
+                        # On the first tool-call delta, replay the text and
+                        # token ids left over from the reasoning phase
+                        if not added_content_delta_arr[i]:
+                            added_content_delta_arr[i] = True
+                            previous_text = ""
+                            previous_token_ids = []
+                            delta_text = current_text
+                            delta_token_ids = current_token_ids
+
+                        delta_message = (
+                            tool_parser.extract_tool_calls_streaming(
+                                previous_text=previous_text,
+                                current_text=current_text,
+                                delta_text=delta_text,
+                                previous_token_ids=previous_token_ids,
+                                current_token_ids=current_token_ids,
+                                delta_token_ids=delta_token_ids,
+                                request=request))
+
+                # when only tool calling is enabled
+                elif tool_choice_auto:
+                    assert tool_parser is not None
                     delta_message = (
                         tool_parser.extract_tool_calls_streaming(
                             previous_text=previous_text,
@@ -507,23 +599,9 @@ class OpenAIServingChat(OpenAIServing):
                             current_token_ids=current_token_ids,
                             delta_token_ids=output.token_ids,
                             request=request))
-
-                    # update the previous values for the next iteration
-                    previous_texts[i] = current_text
-                    all_previous_token_ids[i] = current_token_ids
 
-                # reasoning_content cannot be enabled with tool_choice.
-                # If it is, the tool_choice will be used instead.
+                # when only reasoning is enabled
                 elif self.enable_reasoning:
-                    # handle reasoning_content delta
                     assert reasoning_parser is not None
-                    assert previous_texts is not None
-                    assert all_previous_token_ids is not None
-                    previous_text = previous_texts[i]
-                    previous_token_ids = all_previous_token_ids[i]
-                    current_text = previous_text + delta_text
-                    current_token_ids = previous_token_ids + list(
-                        output.token_ids)
-
                     delta_message = (reasoning_parser.
extract_reasoning_content_streaming( previous_text, @@ -533,15 +611,17 @@ class OpenAIServingChat(OpenAIServing): current_token_ids, output.token_ids, )) - - # update the previous values for the next iteration - previous_texts[i] = current_text - all_previous_token_ids[i] = current_token_ids - # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) + # update the previous values for the next iteration + if tool_choice_auto or should_stream_with_reasoning_parsing: + assert previous_texts is not None + assert all_previous_token_ids is not None + previous_texts[i] = current_text + all_previous_token_ids[i] = current_token_ids + # set the previous values for the next iteration previous_num_tokens[i] += len(output.token_ids) @@ -739,24 +819,24 @@ class OpenAIServingChat(OpenAIServing): except RuntimeError as e: logger.exception("Error in reasoning parser creation.") return self.create_error_response(str(e)) - + # If the reasoning parser is enabled, + # tool calls are extracted exclusively from the content. reasoning_content, content = ( reasoning_parser.extract_reasoning_content( output.text, request=request)) - - if reasoning_content: - message = ChatMessage(role=role, - content=content, - reasoning_content=reasoning_content) - else: - message = ChatMessage(role=role, content=output.text) + else: + reasoning_content = None + content = output.text # if auto tools are not enabled, and a named tool choice using # outlines is not being used - elif (not self.enable_auto_tools - or not self.tool_parser) and not isinstance( - request.tool_choice, ChatCompletionNamedToolChoiceParam): - message = ChatMessage(role=role, content=output.text) + if (not self.enable_auto_tools + or not self.tool_parser) and not isinstance( + request.tool_choice, + ChatCompletionNamedToolChoiceParam): + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) # if the request uses tools and specified a tool choice elif request.tool_choice and type( @@ -766,18 +846,21 @@ class OpenAIServingChat(OpenAIServing): tokenizer, MistralTokenizer) else ToolCall message = ChatMessage( role=role, + reasoning_content=reasoning_content, content="", tool_calls=[ tool_call_class(function=FunctionCall( name=request.tool_choice.function.name, - arguments=output.text)) + arguments=content)) ]) # if the request doesn't use tool choice # OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": - message = ChatMessage(role=role, content=output.text) + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) # handle when there are tools and tool choice is auto elif request.tools and ( @@ -792,20 +875,23 @@ class OpenAIServingChat(OpenAIServing): return self.create_error_response(str(e)) tool_call_info = tool_parser.extract_tool_calls( - output.text, request=request) + content if content is not None else "", request=request) # In the OpenAI API the finish_reason is "tools_called" # if the tool choice is auto and the model produced a tool # call. The same is not true for named function calls auto_tools_called = tool_call_info.tools_called if tool_call_info.tools_called: message = ChatMessage(role=role, + reasoning_content=reasoning_content, content=tool_call_info.content, tool_calls=tool_call_info.tool_calls) else: # FOR NOW make it a chat message; we will have to detect # the type to make it later. 
- message = ChatMessage(role=role, content=output.text) + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) # undetermined case that is still important to handle else: @@ -813,7 +899,9 @@ class OpenAIServingChat(OpenAIServing): "Error in chat_completion_full_generator - cannot determine" " if tools should be extracted. Returning a standard chat " "completion.") - message = ChatMessage(role=role, content=output.text) + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) choice_data = ChatCompletionResponseChoice( index=output.index,
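
To make the token-id contract this patch adds to `ReasoningParser` concrete, here is a minimal, self-contained sketch. The `THINK_END` id and the token lists are made-up values for illustration; a real parser reads the think-end id from the tokenizer, as `DeepSeekR1ReasoningParser` does above.

```python
# Minimal sketch of the token-id contract added to ReasoningParser.
# THINK_END and the token id lists are hypothetical values; a real
# parser looks the id up in the tokenizer's vocabulary.
THINK_END = 151649  # hypothetical id of the "</think>" token


def is_reasoning_end(input_ids: list[int]) -> bool:
    # Mirrors DeepSeekR1ReasoningParser: reasoning is over once the
    # think-end token has been generated.
    return THINK_END in input_ids


def extract_content_ids(input_ids: list[int]) -> list[int]:
    # Mirrors DeepSeekR1ReasoningParser: only ids *after* the think-end
    # token count as `content` (and thus reach the tool-call parser).
    if THINK_END not in input_ids[:-1]:
        return []
    return input_ids[input_ids.index(THINK_END) + 1:]


reasoning = [100, 101, 102]  # "...thinking..."
content = [200, 201]         # '{"city": "Dallas"...'
output = reasoning + [THINK_END] + content

assert is_reasoning_end(output)
assert extract_content_ids(output) == content
assert extract_content_ids(reasoning) == []  # still inside the think block
```

The `[:-1]` guard mirrors the DeepSeek R1 implementation: when the think-end token is the very last id generated, there is no content yet to hand to the tool-call parser.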
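The streaming branch added to `serving_chat.py` is easiest to read as a small per-choice state machine: emit `reasoning_content` deltas until the think-end marker appears, then replay whatever text followed the marker into the tool-call path. The sketch below simulates that flow on plain strings; `stream`, the event tuples, and the string-based `</think>` matching are illustrative stand-ins for the real token-id and `DeltaMessage` logic.

```python
# A simplified, string-based model of the per-choice streaming state in
# serving_chat.py. "content" events stand in for text that the real
# code hands to the tool-call parser rather than emitting directly.
THINK_END = "</think>"


def stream(deltas: list[str]) -> list[tuple[str, str]]:
    events: list[tuple[str, str]] = []
    reasoning_end = False  # mirrors reasoning_end_arr[i]
    carry = ""  # text after </think> not yet sent to the tool parser
    for delta in deltas:
        if not reasoning_end:
            if THINK_END in delta:
                # reasoning just ended: split the delta around the marker
                reasoning_end = True
                before, after = delta.split(THINK_END, 1)
                if before:
                    events.append(("reasoning_content", before))
                carry = after  # replayed on the next content delta
            else:
                events.append(("reasoning_content", delta))
        else:
            # the first content delta carries over what followed </think>
            events.append(("content", carry + delta))
            carry = ""
    return events


print(stream(["Need the", " weather tool</think>", '{"city": "Dallas"}']))
# [('reasoning_content', 'Need the'),
#  ('reasoning_content', ' weather tool'),
#  ('content', '{"city": "Dallas"}')]
```

The `carry` variable plays the same role as the patch's reset of `previous_text`/`previous_token_ids` on the first tool-call delta: leftover text after the think-end marker must reach the tool parser as part of a fresh delta.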