From a164aea35d48a4ef7316c203ef89720b0267d9b4 Mon Sep 17 00:00:00 2001 From: Kinfey <93169410+kinfey@users.noreply.github.com> Date: Tue, 1 Apr 2025 13:50:05 +0800 Subject: [PATCH] [Frontend] Add Phi-4-mini function calling support (#14886) Signed-off-by: Kinfey Co-authored-by: Cyrus Leung --- examples/tool_chat_template_phi4_mini.jinja | 60 ++++++++++ .../openai/tool_parsers/__init__.py | 3 +- .../tool_parsers/phi4mini_tool_parser.py | 108 ++++++++++++++++++ 3 files changed, 170 insertions(+), 1 deletion(-) create mode 100644 examples/tool_chat_template_phi4_mini.jinja create mode 100644 vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py diff --git a/examples/tool_chat_template_phi4_mini.jinja b/examples/tool_chat_template_phi4_mini.jinja new file mode 100644 index 00000000..36423b6c --- /dev/null +++ b/examples/tool_chat_template_phi4_mini.jinja @@ -0,0 +1,60 @@ +{%- if messages %} + {%- if system_message or tools %} +<|system|> + +{%- if system_message %} +{{ system_message }} +{%- endif %} +In addition to plain text responses, you can chose to call one or more of the provided functions. + +Use the following rule to decide when to call a function: + * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so + * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls + +If you decide to call functions: + * prefix function calls with functools marker (no closing marker required) + * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] + * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples + * respect the argument type formatting. 
E.g., if the type is number and format is float, write value 7 as 7.0
# SPDX-License-Identifier: Apache-2.0

import json
import re
from collections.abc import Sequence
from typing import Any, Optional

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.utils import random_uuid

logger = init_logger(__name__)


@ToolParserManager.register_module("phi4_mini_json")
class Phi4MiniJsonToolParser(ToolParser):
    """
    Tool call parser for phi-4-mini models intended for use with the
    examples/tool_chat_template_phi4_mini.jinja template.

    Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json
    are all set
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
        super().__init__(tokenizer)

        # State used when parsing tool calls in streaming mode (streaming
        # extraction is currently not implemented; see
        # extract_tool_calls_streaming below).
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        # What has been streamed for each tool so far.
        self.streamed_args_for_tool: list[str] = []
        # Marker the model emits immediately before the JSON list of calls.
        self.bot_token: str = "functools"

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        The model signals tool use with ``functools[...]`` where the
        bracketed payload is a JSON list of
        ``{"name": ..., "arguments"/"parameters": ...}`` objects.

        Returns an ExtractedToolCallInformation with parsed tool calls, or
        with ``tools_called=False`` and the raw output as content when no
        well-formed call list is present.
        """
        logger.debug("Model output: %s", model_output)

        # Locate the opening bracket of the call list. We then hand the
        # string to raw_decode instead of capturing with a non-greedy
        # regex: a pattern like r'functools\[(.*?)\]' would stop at the
        # first ']' and truncate any arguments that themselves contain
        # JSON lists (nested brackets).
        match = re.search(r'functools\[', model_output)
        if match is None:
            logger.debug("No function calls found")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            # raw_decode parses one complete JSON value starting at the
            # '[' and ignores any trailing text after the list.
            function_call_arr, _ = json.JSONDecoder().raw_decode(
                model_output, match.end() - 1)
            logger.debug("Successfully extracted %d function calls",
                         len(function_call_arr))
        except json.JSONDecodeError:
            # Malformed payload: surface the raw output to the caller
            # rather than claiming tools were called with an empty list.
            logger.exception("Error parsing function call JSON")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            tool_calls: list[ToolCall] = [
                ToolCall(
                    id=f"chatcmpl-tool-{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=raw_function_call["name"],
                        # function call args are JSON but as a string;
                        # some outputs use "parameters" instead of
                        # "arguments".
                        arguments=json.dumps(
                            raw_function_call["arguments"]
                            if "arguments" in raw_function_call else
                            raw_function_call["parameters"])))
                for raw_function_call in function_call_arr
            ]

            return ExtractedToolCallInformation(tools_called=bool(tool_calls),
                                                tool_calls=tool_calls,
                                                content=None)
        except Exception:
            # e.g. a call object missing both "arguments" and
            # "parameters", or a non-dict list element.
            logger.exception("Error extracting tool calls")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Optional[DeltaMessage]:
        # Streaming tool-call extraction is not implemented for this
        # parser yet; returning None emits no tool-call deltas.
        return None