[Frontend][Feature] support tool calling for internlm/internlm2_5-7b-chat model (#8405)
parent 2838d6b38e · commit 3dbb215b38
@@ -12,4 +12,5 @@ torch
 py-cpuinfo
 transformers
 mistral_common >= 1.3.4
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
+partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
@@ -157,8 +157,9 @@ vLLM will use guided decoding to ensure the response matches the tool parameter
 To enable this feature, you should set the following flags:
 * `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it
 deems appropriate.
-* `--tool-call-parser` -- select the tool parser to use - currently either `hermes`, `mistral` or `llama3_json`. Additional tool parsers
-will continue to be added in the future.
+* `--tool-call-parser` -- select the tool parser to use - currently `hermes`, `mistral`, `llama3_json` or `internlm`. Additional tool parsers
+will continue to be added in the future, and you can also register your own tool parsers with `--tool-parser-plugin`.
+* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user-defined tool parsers into vLLM; the parser names registered there can then be passed to `--tool-call-parser`.
 * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages
 that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their
 `tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
@@ -218,4 +219,73 @@ it works better with vLLM.
 
 Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`
 
+
+#### Internlm Models
+
+Supported models:
+* `internlm/internlm2_5-7b-chat` (confirmed)
+* Additional internlm2.5 function-calling models are compatible as well
+
+Known issues:
+* Although this implementation also supports InternLM2, the tool call results are not stable when tested with the `internlm/internlm2-chat-7b` model.
+
+Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja`
+
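For reference, a minimal client-side sketch of exercising this setup (not part of this diff; it assumes a server already launched with the flags above plus `--trust-remote-code`, and the standard `openai` Python client — the tool definition is a hypothetical example):

```python
# Hedged sketch: assumes something like
#   vllm serve internlm/internlm2_5-7b-chat --trust-remote-code \
#       --enable-auto-tool-choice --tool-call-parser internlm \
#       --chat-template examples/tool_chat_template_internlm2_tool.jinja
# is already listening on localhost:8000.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="empty")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",  # hypothetical tool for illustration
        "description": "Get the current weather in a given city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="internlm/internlm2_5-7b-chat",
    messages=[{"role": "user", "content": "What's the weather in Dallas?"}],
    tools=tools,
    tool_choice="auto",
)
# With auto tool choice enabled, the parser turns the raw
# <|action_start|><|plugin|>... output into structured tool_calls.
print(response.choices[0].message.tool_calls)
```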
+### How to write a tool parser plugin
+
+A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.
+
+Here is a summary of a plugin file:
+
+```python
+# import the required packages
+from typing import Sequence, Union
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage,
+                                              ExtractedToolCallInformation)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser, ToolParserManager)
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+
+# define a tool parser and register it to vllm
+# the name list in register_module can be used
+# in --tool-call-parser. you can define as many
+# tool parsers as you want here.
+@ToolParserManager.register_module(["example"])
+class ExampleToolParser(ToolParser):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+    # adjust the request, e.g. set skip_special_tokens
+    # to False for tool call output.
+    def adjust_request(
+            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+        return request
+
+    # implement the tool call parsing for streaming calls
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+        return delta
+
+    # implement the tool call parsing for non-streaming calls
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+        return ExtractedToolCallInformation(tools_called=False,
+                                            tool_calls=[],
+                                            content=model_output)
+```
+Then you can use this plugin on the command line like this:
+
+```
+    --enable-auto-tool-choice \
+    --tool-parser-plugin <absolute path of the plugin file> \
+    --tool-call-parser example \
+    --chat-template <your chat template>
+```
examples/tool_chat_template_internlm2_tool.jinja (new file)
@@ -0,0 +1,60 @@
+{%- if messages[0]["role"] == "system" %}
+    {%- set system_message = messages[0]["content"] %}
+    {%- set loop_messages = messages[1:] %}
+{%- else %}
+    {%- set loop_messages = messages %}
+{%- endif %}
+
+{%- if not tools is defined %}
+    {%- set tools = none %}
+{%- endif %}
+
+{{- bos_token }}
+{%- if system_message is defined %}
+    {{- "<|im_start|>system\n" + system_message + "<|im_end|>\n" }}
+{%- endif %}
+
+{%- if tools is not none %}
+    {{- "<|im_start|>system name=<|plugin|>\n[" }}
+    {%- for tool in tools %}
+        {{- tool.function|tojson }}
+        {%- if not loop.last %}
+            {{- ", " }}
+        {%- else %}
+            {{- "]" }}
+        {%- endif %}
+    {%- endfor %}
+    {{- "<|im_end|>\n" }}
+{%- endif %}
+
+{%- for message in loop_messages %}
+    {%- if message["role"] == "user" %}
+        {{- "<|im_start|>user\n" + message["content"] + "<|im_end|>\n"}}
+    {%- elif message.tool_calls is defined and message.tool_calls is not none %}
+        {%- set content = message["content"] if message["content"] else "" %}
+        {{- "<|im_start|>assistant\n" + content }}
+        {%- for tool_call in message.tool_calls %}
+            {%- set function=tool_call.function %}
+            {{- "<|action_start|><|plugin|>\n" }}
+            {{- '{"name": "' + function.name + '", '}}
+            {{- '"arguments": ' + function.arguments|tojson + '}' }}
+            {{- "<|action_end|>" }}
+        {%- endfor %}
+        {{- "<|im_end|>\n" }}
+    {%- elif message["role"] == "assistant" %}
+        {{- "<|im_start|>assistant\n" + message["content"] + "<|im_end|>\n"}}
+    {%- elif message["role"] == "tool_results" or message["role"] == "tool" or message["role"] == "function" %}
+        {%- if message.content is defined and message.content.content is defined %}
+            {%- set content = message.content.content %}
+        {%- else %}
+            {%- set content = message.content %}
+        {%- endif %}
+        {{- "<|im_start|>environment name=<|plugin|>\n" + content|string + "<|im_end|>\n" }}
+    {%- else %}
+        {{- raise_exception("Only user and assistant and tool_results and tool and function roles are supported, with the exception of an initial optional system message!") }}
+    {%- endif %}
+{%- endfor %}
+
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+{%- endif %}
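To sanity-check the template above, here is a hedged offline sketch (not part of this diff) that renders it with `transformers`; it assumes a `transformers` version whose `apply_chat_template` accepts the `tools` and `chat_template` arguments, and uses a hypothetical tool definition:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2_5-7b-chat",
                                          trust_remote_code=True)
with open("examples/tool_chat_template_internlm2_tool.jinja") as f:
    chat_template = f.read()

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",  # hypothetical tool for illustration
        "description": "Get the current weather in a given city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
messages = [{"role": "user", "content": "What is the weather in Dallas?"}]

# Should print the <|im_start|>system name=<|plugin|> tool block, the user
# turn, and the trailing <|im_start|>assistant generation prompt.
print(tokenizer.apply_chat_template(messages,
                                    tools=tools,
                                    chat_template=chat_template,
                                    tokenize=False,
                                    add_generation_prompt=True))
```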
@@ -87,6 +87,18 @@ CONFIGS: Dict[str, ServerConfig] = {
         "call the tool. Otherwise, answer the user's query directly "
         "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
         "to the user's question - just respond to it normally."
+    },
+    "internlm": {
+        "model":
+        "internlm/internlm2_5-7b-chat",
+        "arguments": [
+            "--tool-call-parser", "internlm", "--chat-template",
+            str(VLLM_PATH /
+                "examples/tool_chat_template_internlm2_tool.jinja"),
+            "--trust_remote_code"
+        ],
+        "supports_parallel":
+        False,
     }
 }
@@ -109,7 +121,7 @@ WEATHER_TOOL: ChatCompletionToolParam = {
                 "type":
                 "string",
                 "description":
-                "the two-letter abbreviation for the state "
+                "must be the two-letter abbreviation for the state "
                 "that the city is in, e.g. 'CA' which would "
                 "mean 'California'"
             },
@@ -53,6 +53,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
@@ -526,6 +527,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
 
+    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
+        ToolParserManager.import_tool_parser(args.tool_parser_plugin)
+
+    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
+    if args.enable_auto_tool_choice \
+            and args.tool_call_parser not in valid_tool_parsers:
+        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
+                       f"(choose from {{ {','.join(valid_tool_parsers)} }})")
+
     temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     temp_socket.bind(("", args.port))
@@ -12,6 +12,7 @@ from typing import List, Optional, Sequence, Union
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     PromptAdapterPath)
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.utils import FlexibleArgumentParser
@@ -190,16 +191,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "Enable auto tool choice for supported models. Use --tool-call-parser"
         "to specify which parser to use")
 
+    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
     parser.add_argument(
         "--tool-call-parser",
         type=str,
-        choices=["mistral", "hermes", "llama3_json"],
+        metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
+        "--tool-parser-plugin",
         default=None,
         help=
         "Select the tool call parser depending on the model that you're using."
         " This is used to parse the model-generated tool call into OpenAI API "
         "format. Required for --enable-auto-tool-choice.")
 
+    parser.add_argument(
+        "--tool-parser-plugin",
+        type=str,
+        default="",
+        help=
+        "Specify the tool parser plugin used to parse the model-generated tool"
+        " calls into OpenAI API format; the names registered in this plugin "
+        "can be used in --tool-call-parser.")
+
     parser = AsyncEngineArgs.add_cli_args(parser)
 
     parser.add_argument('--max-log-len',
@@ -29,10 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     OpenAIServing,
                                                     PromptAdapterPath,
                                                     TextTokensPrompt)
-from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
-                                                  Llama3JsonToolParser,
-                                                  MistralToolParser,
-                                                  ToolParser)
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -82,15 +79,13 @@ class OpenAIServingChat(OpenAIServing):
 
         self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
         if self.enable_auto_tools:
-            if tool_parser == "mistral":
-                self.tool_parser = MistralToolParser
-            elif tool_parser == "hermes":
-                self.tool_parser = Hermes2ProToolParser
-            elif tool_parser == "llama3_json":
-                self.tool_parser = Llama3JsonToolParser
-            else:
-                raise TypeError("Error: --enable-auto-tool-choice requires "
-                                "--tool-call-parser")
+            try:
+                self.tool_parser = ToolParserManager.get_tool_parser(
+                    tool_parser)
+            except Exception as e:
+                raise TypeError("Error: --enable-auto-tool-choice requires "
+                                f"tool_parser:'{tool_parser}' which has not "
+                                "been registered") from e
 
     async def create_chat_completion(
         self,
@@ -187,6 +182,10 @@ class OpenAIServingChat(OpenAIServing):
         raw_request.state.request_metadata = request_metadata
 
         try:
+            if self.enable_auto_tools and self.tool_parser:
+                request = self.tool_parser(tokenizer).adjust_request(
+                    request=request)
+
             if isinstance(prompt, str):
                 prompt_inputs = self._tokenize_prompt_input(
                     request,
@@ -282,11 +281,11 @@ class OpenAIServingChat(OpenAIServing):
         num_choices = 1 if request.n is None else request.n
         previous_num_tokens = [0] * num_choices
         finish_reason_sent = [False] * num_choices
 
         num_prompt_tokens = 0
 
-        tool_parser: Optional[ToolParser] = self.tool_parser(
-            tokenizer) if self.tool_parser else None
+        tool_parsers: List[Optional[ToolParser]] = [
+            self.tool_parser(tokenizer) if self.tool_parser else None
+        ] * num_choices
 
         if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
             tool_choice_function_name = request.tool_choice.function.name
@@ -324,7 +323,7 @@ class OpenAIServingChat(OpenAIServing):
             # NOTE num_choices defaults to 1 so this usually executes
             # once per request
             for i in range(num_choices):
+                tool_parser = tool_parsers[i]
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=i,
                     delta=DeltaMessage(
@@ -399,6 +398,7 @@ class OpenAIServingChat(OpenAIServing):
 
                 for output in res.outputs:
                     i = output.index
+                    tool_parser = tool_parsers[i]
 
                     if finish_reason_sent[i]:
                         continue
@@ -446,7 +446,8 @@ class OpenAIServingChat(OpenAIServing):
                             delta_text=delta_text,
                             previous_token_ids=previous_token_ids,
                             current_token_ids=current_token_ids,
-                            delta_token_ids=output.token_ids))
+                            delta_token_ids=output.token_ids,
+                            request=request))
 
                     # update the previous values for the next iteration
                     previous_texts[i] = current_text
@@ -685,7 +686,8 @@ class OpenAIServingChat(OpenAIServing):
               and self.tool_parser:
 
             tool_parser = self.tool_parser(tokenizer)
-            tool_call_info = tool_parser.extract_tool_calls(output.text)
+            tool_call_info = tool_parser.extract_tool_calls(
+                output.text, request=request)
             tools_called = tool_call_info.tools_called
             if tool_call_info.tools_called:
                 message = ChatMessage(role=role,
@@ -1,9 +1,10 @@
-from .abstract_tool_parser import ToolParser
+from .abstract_tool_parser import ToolParser, ToolParserManager
 from .hermes_tool_parser import Hermes2ProToolParser
+from .internlm2_tool_parser import Internlm2ToolParser
 from .llama_tool_parser import Llama3JsonToolParser
 from .mistral_tool_parser import MistralToolParser
 
 __all__ = [
-    "ToolParser", "Hermes2ProToolParser", "MistralToolParser",
-    "Llama3JsonToolParser"
+    "ToolParser", "ToolParserManager", "Hermes2ProToolParser",
+    "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
 ]
@@ -1,9 +1,14 @@
-from typing import Dict, List, Sequence, Union
+import importlib
+import importlib.util
+import os
+from typing import Callable, Dict, List, Optional, Sequence, Type, Union
 
-from vllm.entrypoints.openai.protocol import (DeltaMessage,
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage,
                                               ExtractedToolCallInformation)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import is_list_of
 
 logger = init_logger(__name__)
@@ -24,8 +29,16 @@ class ToolParser:
 
         self.model_tokenizer = tokenizer
 
-    def extract_tool_calls(self,
-                           model_output: str) -> ExtractedToolCallInformation:
+    def adjust_request(
+            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+        """
+        Method that is used to adjust the request parameters.
+        """
+        return request
+
+    def extract_tool_calls(
+            self, model_output: str,
+            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
         """
         Static method that should be implemented for extracting tool calls from
         a complete model-generated string.
@@ -44,6 +57,7 @@ class ToolParser:
             previous_token_ids: Sequence[int],
             current_token_ids: Sequence[int],
             delta_token_ids: Sequence[int],
+            request: ChatCompletionRequest,
         ) -> Union[DeltaMessage, None]:
         """
         Instance method that should be implemented for extracting tool calls
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"AbstractToolParser.extract_tool_calls_streaming has not been "
|
"AbstractToolParser.extract_tool_calls_streaming has not been "
|
||||||
"implemented!")
|
"implemented!")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolParserManager:
|
||||||
|
tool_parsers: Dict[str, Type] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_tool_parser(cls, name) -> Type:
|
||||||
|
"""
|
||||||
|
Get tool parser by name which is registered by `register_module`.
|
||||||
|
|
||||||
|
Raise a KeyError exception if the name is not registered.
|
||||||
|
"""
|
||||||
|
if name in cls.tool_parsers:
|
||||||
|
return cls.tool_parsers[name]
|
||||||
|
|
||||||
|
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _register_module(cls,
|
||||||
|
module: Type,
|
||||||
|
module_name: Optional[Union[str, List[str]]] = None,
|
||||||
|
force: bool = True) -> None:
|
||||||
|
if not issubclass(module, ToolParser):
|
||||||
|
raise TypeError(
|
||||||
|
f'module must be subclass of ToolParser, but got {type(module)}'
|
||||||
|
)
|
||||||
|
if module_name is None:
|
||||||
|
module_name = module.__name__
|
||||||
|
if isinstance(module_name, str):
|
||||||
|
module_name = [module_name]
|
||||||
|
for name in module_name:
|
||||||
|
if not force and name in cls.tool_parsers:
|
||||||
|
existed_module = cls.tool_parsers[name]
|
||||||
|
raise KeyError(f'{name} is already registered '
|
||||||
|
f'at {existed_module.__module__}')
|
||||||
|
cls.tool_parsers[name] = module
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_module(
|
||||||
|
cls,
|
||||||
|
name: Optional[Union[str, List[str]]] = None,
|
||||||
|
force: bool = True,
|
||||||
|
module: Union[Type, None] = None) -> Union[type, Callable]:
|
||||||
|
"""
|
||||||
|
Register module with the given name or name list. it can be used as a
|
||||||
|
decoder(with module as None) or normal function(with module as not
|
||||||
|
None).
|
||||||
|
"""
|
||||||
|
if not isinstance(force, bool):
|
||||||
|
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
||||||
|
|
||||||
|
# raise the error ahead of time
|
||||||
|
if not (name is None or isinstance(name, str)
|
||||||
|
or is_list_of(name, str)):
|
||||||
|
raise TypeError(
|
||||||
|
'name must be None, an instance of str, or a sequence of str, '
|
||||||
|
f'but got {type(name)}')
|
||||||
|
|
||||||
|
# use it as a normal method: x.register_module(module=SomeClass)
|
||||||
|
if module is not None:
|
||||||
|
cls._register_module(module=module, module_name=name, force=force)
|
||||||
|
return module
|
||||||
|
|
||||||
|
# use it as a decorator: @x.register_module()
|
||||||
|
def _register(module):
|
||||||
|
cls._register_module(module=module, module_name=name, force=force)
|
||||||
|
return module
|
||||||
|
|
||||||
|
return _register
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def import_tool_parser(cls, plugin_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Import a user defined tool parser by the path of the tool parser define
|
||||||
|
file.
|
||||||
|
"""
|
||||||
|
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, plugin_path)
|
||||||
|
if spec is None or spec.loader is None:
|
||||||
|
logger.error("load %s from %s failed.", module_name, plugin_path)
|
||||||
|
return
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
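A hedged usage sketch of the manager above (not part of this diff; `MyToolParser` is a hypothetical subclass):

```python
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)


# decorator form: registers the class under one or more names
@ToolParserManager.register_module(["my_parser"])
class MyToolParser(ToolParser):
    pass


# normal-function form, equivalent to the decorator above:
# ToolParserManager.register_module(name="other_name", module=MyToolParser)

# lookup by registered name, as serving_chat.py now does
parser_cls = ToolParserManager.get_tool_parser("my_parser")
assert parser_cls is MyToolParser
```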
@@ -5,12 +5,13 @@ from typing import Dict, List, Sequence, Union
 import partial_json_parser
 from partial_json_parser.core.options import Allow
 
-from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
                                               ExtractedToolCallInformation,
                                               FunctionCall, ToolCall)
 from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
-    ToolParser)
+    ToolParser, ToolParserManager)
 from vllm.entrypoints.openai.tool_parsers.utils import (
     extract_intermediate_diff)
 from vllm.logger import init_logger
@@ -20,6 +21,7 @@ from vllm.utils import random_uuid
 logger = init_logger(__name__)
 
 
+@ToolParserManager.register_module("hermes")
 class Hermes2ProToolParser(ToolParser):
 
     def __init__(self, tokenizer: AnyTokenizer):
@@ -57,8 +59,11 @@ class Hermes2ProToolParser(ToolParser):
             "Hermes 2 Pro Tool parser could not locate tool call start/end "
             "tokens in the tokenizer!")
 
-    def extract_tool_calls(self,
-                           model_output: str) -> ExtractedToolCallInformation:
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
 
         # sanity check; avoid unnecessary processing
         if self.tool_call_start_token not in model_output:
@@ -114,6 +119,7 @@ class Hermes2ProToolParser(ToolParser):
             previous_token_ids: Sequence[int],
             current_token_ids: Sequence[int],
             delta_token_ids: Sequence[int],
+            request: ChatCompletionRequest,
         ) -> Union[DeltaMessage, None]:
 
         logger.debug("delta_text: %s", delta_text)
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py (new file)
@@ -0,0 +1,208 @@
+import json
+from typing import Dict, Sequence, Union
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
+                                              DeltaToolCall,
+                                              ExtractedToolCallInformation,
+                                              FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser, ToolParserManager)
+from vllm.entrypoints.openai.tool_parsers.utils import (
+    extract_intermediate_diff)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module(["internlm"])
+class Internlm2ToolParser(ToolParser):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+        self.position = 0
+
+    def adjust_request(
+            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+        if request.tools and request.tool_choice != 'none':
+            # do not skip special tokens because internlm uses the special
+            # tokens to indicate the start and end of the tool call
+            # information.
+            request.skip_special_tokens = False
+        return request
+
+    def get_argments(self, obj):
+        if "parameters" in obj:
+            return obj.get("parameters")
+        elif "arguments" in obj:
+            return obj.get("arguments")
+        return None
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+        if '<|action_start|>' not in current_text:
+            self.position = len(current_text)
+            return DeltaMessage(content=delta_text)
+        # if the tool call has been sent, return an empty delta message
+        # to make sure the finish_reason will be sent correctly.
+        if self.current_tool_id > 0:
+            return DeltaMessage(content='')
+
+        last_pos = self.position
+        if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
+            return None
+
+        new_delta = current_text[last_pos:]
+        text, action = new_delta.split('<|action_start|><|plugin|>')
+
+        if len(text) > 0:
+            self.position = self.position + len(text)
+            return DeltaMessage(content=text)
+
+        action = action.strip()
+        action = action.split('<|action_end|>'.strip())[0]
+
+        # bit mask flags for partial JSON parsing. If the name hasn't been
+        # sent yet, don't allow sending
+        # an incomplete string since OpenAI only ever (as far as I have
+        # seen) allows sending the entire tool/ function name at once.
+        flags = Allow.ALL if self.current_tool_name_sent \
+            else Allow.ALL & ~Allow.STR
+
+        try:
+            parsable_arr = action
+
+            # tool calls are generated as a single object in internlm2;
+            # parallel tool calls are not supported
+            try:
+                tool_call_arr: Dict = partial_json_parser.loads(
+                    parsable_arr, flags)
+            except partial_json_parser.core.exceptions.MalformedJSON:
+                logger.debug('not enough tokens to parse into JSON yet')
+                return None
+
+            # if the current tool name hasn't been sent, send if available
+            # - otherwise send nothing
+            if not self.current_tool_name_sent:
+                function_name = tool_call_arr.get("name")
+                if function_name:
+                    self.current_tool_id = self.current_tool_id + 1
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      type="function",
+                                      id=f"chatcmpl-tool-{random_uuid()}",
+                                      function=DeltaFunctionCall(
+                                          name=function_name).model_dump(
+                                              exclude_none=True))
+                    ])
+                    self.current_tool_name_sent = True
+                    self.streamed_args_for_tool.append("")
+                else:
+                    delta = None
+            # now we know we're on the same tool call and we're streaming
+            # arguments
+            else:
+                prev_arguments = self.get_argments(
+                    self.prev_tool_call_arr[self.current_tool_id])
+                cur_arguments = self.get_argments(tool_call_arr)
+
+                # no arguments generated yet
+                if not cur_arguments and not prev_arguments:
+                    delta = None
+                # will never happen
+                elif not cur_arguments and prev_arguments:
+                    logger.error(
+                        "INVARIANT - impossible to have arguments reset "
+                        "mid-arguments")
+                    delta = None
+                # first time we get the parameters
+                elif cur_arguments and not prev_arguments:
+                    cur_arguments_json = json.dumps(cur_arguments)
+
+                    arguments_delta = cur_arguments_json[:cur_arguments_json.
+                                                         index(delta_text) +
+                                                         len(delta_text)]
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      function=DeltaFunctionCall(
+                                          arguments=arguments_delta).
+                                      model_dump(exclude_none=True))
+                    ])
+                    self.streamed_args_for_tool[
+                        self.current_tool_id] += arguments_delta
+                # both prev and cur parameters exist, send the incremental
+                # arguments
+                elif cur_arguments and prev_arguments:
+                    cur_args_json = json.dumps(cur_arguments)
+                    prev_args_json = json.dumps(prev_arguments)
+
+                    argument_diff = extract_intermediate_diff(
+                        cur_args_json, prev_args_json)
+
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      function=DeltaFunctionCall(
+                                          arguments=argument_diff).model_dump(
+                                              exclude_none=True))
+                    ])
+                    self.streamed_args_for_tool[
+                        self.current_tool_id] += argument_diff
+
+            # check to see if the name is defined and has been sent. if so,
+            # stream the name - otherwise keep waiting
+            # finish by setting old and returning None as base case
+            tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
+            self.prev_tool_call_arr = [tool_call_arr]
+            return delta
+        except Exception as e:
+            logger.error("Error trying to handle streaming tool call: %s", e)
+            logger.debug(
+                "Skipping chunk as a result of tool streaming extraction "
+                "error")
+            return None
+
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+        text = model_output
+        tools = request.tools
+        if '<|action_start|><|plugin|>' in text:
+            text, action = text.split('<|action_start|><|plugin|>')
+            action = action.split('<|action_end|>'.strip())[0]
+            action = action[action.find('{'):]
+            action_dict = json.loads(action)
+            name, parameters = action_dict['name'], json.dumps(
+                action_dict.get('parameters', action_dict.get('arguments',
+                                                              {})))
+
+            if not tools or name not in [t.function.name for t in tools]:
+                return ExtractedToolCallInformation(tools_called=False,
+                                                    tool_calls=[],
+                                                    content=text)
+
+            tool_calls = [
+                ToolCall(
+                    function=FunctionCall(name=name, arguments=parameters))
+            ]
+            return ExtractedToolCallInformation(
+                tools_called=True,
+                tool_calls=tool_calls,
+                content=text if len(text) > 0 else None)
+
+        return ExtractedToolCallInformation(tools_called=False,
+                                            tool_calls=[],
+                                            content=text)
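For illustration, a hedged sketch (not part of this diff) of the wire format the parser above consumes; the sample string mirrors what the InternLM2 chat template emits, and the splitting steps mirror `extract_tool_calls`:

```python
import json

# sample InternLM2 output: leading text followed by one tool call
sample_output = (
    "I will look that up."
    "<|action_start|><|plugin|>\n"
    '{"name": "get_current_weather", '
    '"parameters": {"city": "Dallas", "state": "TX"}}'
    "<|action_end|>"
)

# same splitting steps as extract_tool_calls() above
text, action = sample_output.split("<|action_start|><|plugin|>")
action = action.split("<|action_end|>")[0]
action_dict = json.loads(action[action.find("{"):])

print(text)                 # "I will look that up."
print(action_dict["name"])  # "get_current_weather"
print(json.dumps(action_dict.get("parameters", {})))  # the arguments JSON
```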
@@ -7,12 +7,13 @@ import partial_json_parser
 from partial_json_parser.core.options import Allow
 from transformers import PreTrainedTokenizerBase
 
-from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
                                               ExtractedToolCallInformation,
                                               FunctionCall, ToolCall)
 from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
-    ToolParser)
+    ToolParser, ToolParserManager)
 from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
@@ -41,6 +42,7 @@ def is_complete_json(input_str):
     return False
 
 
+@ToolParserManager.register_module("llama3_json")
 class Llama3JsonToolParser(ToolParser):
     """
     Tool call parser for Llama 3.1 models intended for use with the
@@ -64,8 +66,9 @@ class Llama3JsonToolParser(ToolParser):
             add_special_tokens=False)[0]
         self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
 
-    def extract_tool_calls(self,
-                           model_output: str) -> ExtractedToolCallInformation:
+    def extract_tool_calls(
+            self, model_output: str,
+            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
         """
         Extract the tool calls from a complete model response.
         """
@@ -125,6 +128,7 @@ class Llama3JsonToolParser(ToolParser):
             previous_token_ids: Sequence[int],
             current_token_ids: Sequence[int],
             delta_token_ids: Sequence[int],
+            request: ChatCompletionRequest,
         ) -> Union[DeltaMessage, None]:
 
         if not (current_text.startswith(self.bot_token)
@@ -8,12 +8,13 @@ import partial_json_parser
 from partial_json_parser.core.options import Allow
 from pydantic import Field
 
-from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
                                               ExtractedToolCallInformation,
                                               FunctionCall, ToolCall)
 from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
-    ToolParser)
+    ToolParser, ToolParserManager)
 from vllm.entrypoints.openai.tool_parsers.utils import (
     extract_intermediate_diff)
 from vllm.logger import init_logger
@@ -36,6 +37,7 @@ class MistralToolCall(ToolCall):
         return "".join(choices(ALPHANUMERIC, k=9))
 
 
+@ToolParserManager.register_module("mistral")
 class MistralToolParser(ToolParser):
     """
     Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
@@ -47,9 +49,7 @@ class MistralToolParser(ToolParser):
     def __init__(self, tokenizer: AnyTokenizer):
         super().__init__(tokenizer)
 
-        if isinstance(self.model_tokenizer, MistralTokenizer):
-            self.model_tokenizer = self.model_tokenizer.tokenizer
-        else:
+        if not isinstance(self.model_tokenizer, MistralTokenizer):
             logger.info("Non-Mistral tokenizer detected when using a Mistral "
                         "model...")
@@ -61,11 +61,14 @@ class MistralToolParser(ToolParser):
         self.streamed_args_for_tool: List[str] = [
         ]  # map what has been streamed for each tool so far to a list
         self.bot_token = "[TOOL_CALLS]"
-        self.bot_token_id = self.model_tokenizer.vocab[self.bot_token]
+        self.bot_token_id = self.model_tokenizer.get_vocab()[self.bot_token]
         self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
 
-    def extract_tool_calls(self,
-                           model_output: str) -> ExtractedToolCallInformation:
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
         """
         Extract the tool calls from a complete model response. Requires
         find-and-replacing single quotes with double quotes for JSON parsing,
@@ -119,6 +122,7 @@ class MistralToolParser(ToolParser):
             previous_token_ids: Sequence[int],
             current_token_ids: Sequence[int],
             delta_token_ids: Sequence[int],
+            request: ChatCompletionRequest,
         ) -> Union[DeltaMessage, None]:
 
         # if the tool call token is not in the tokens generated so far, append