[Frontend][Feature] support tool calling for internlm/internlm2_5-7b-chat model (#8405)
This commit is contained in:
parent
2838d6b38e
commit
3dbb215b38
@ -13,3 +13,4 @@ py-cpuinfo
|
||||
transformers
|
||||
mistral_common >= 1.3.4
|
||||
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
|
||||
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
|
@ -157,8 +157,9 @@ vLLM will use guided decoding to ensure the response matches the tool parameter
|
||||
To enable this feature, you should set the following flags:
|
||||
* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it
|
||||
deems appropriate.
|
||||
* `--tool-call-parser` -- select the tool parser to use - currently either `hermes`, `mistral` or `llama3_json`. Additional tool parsers
|
||||
will continue to be added in the future.
|
||||
* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral` or `llama3_json` or `internlm`. Additional tool parsers
|
||||
will continue to be added in the future, and also can register your own tool parsers in the `--tool-parser-plugin`.
|
||||
* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user defined tool parsers into vllm, the registered tool parser name can be specified in `--tool-call-parser`.
|
||||
* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages
|
||||
that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their
|
||||
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
|
||||
@ -218,4 +219,73 @@ it works better with vLLM.
|
||||
|
||||
Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`
|
||||
|
||||
#### Internlm Models
|
||||
Supported models:
|
||||
* `internlm/internlm2_5-7b-chat` (confirmed)
|
||||
* Additional internlm2.5 function-calling models are compatible as well
|
||||
|
||||
Known issues:
|
||||
* Although this implementation also supports Internlm2, the tool call results are not stable when testing with the `internlm/internlm2-chat-7b` model.
|
||||
|
||||
Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja`
|
||||
|
||||
|
||||
### How to write a tool parser plugin
|
||||
|
||||
A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.
|
||||
|
||||
Here is a summary of a plugin file:
|
||||
|
||||
```python
|
||||
|
||||
# import the required packages
|
||||
|
||||
# define a tool parser and register it to vllm
|
||||
# the name list in register_module can be used
|
||||
# in --tool-call-parser. you can define as many
|
||||
# tool parsers as you want here.
|
||||
@ToolParserManager.register_module(["example"])
|
||||
class ExampleToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# adjust request. e.g.: set skip special tokens
|
||||
# to False for tool call output.
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
return request
|
||||
|
||||
# implement the tool call parse for stream call
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
return delta
|
||||
|
||||
# implement the tool parse for non-stream call
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=text)
|
||||
|
||||
|
||||
```
|
||||
|
||||
Then you can use this plugin in the command line like this.
|
||||
```
|
||||
--enable-auto-tool-choice \
|
||||
--tool-parser-plugin <absolute path of the plugin file>
|
||||
--tool-call-parser example \
|
||||
--chat-template <your chat template> \
|
||||
```
|
||||
|
||||
|
60
examples/tool_chat_template_internlm2_tool.jinja
Normal file
60
examples/tool_chat_template_internlm2_tool.jinja
Normal file
@ -0,0 +1,60 @@
|
||||
{%- if messages[0]["role"] == "system" %}
|
||||
{%- set system_message = messages[0]["content"] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- if system_message is defined %}
|
||||
{{- "<|im_start|>system\n" + system_message + "<|im_end|>\n" }}
|
||||
{%- endif %}
|
||||
|
||||
{%- if tools is not none %}
|
||||
{{- "<|im_start|>system name=<|plugin|>\n[" }}
|
||||
{%- for tool in tools %}
|
||||
{{- tool.function|tojson }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "<|im_end|>\n" }}
|
||||
{%- endif %}
|
||||
|
||||
{%- for message in loop_messages %}
|
||||
{%- if message["role"] == "user" %}
|
||||
{{- "<|im_start|>user\n" + message["content"] + "<|im_end|>\n"}}
|
||||
{%- elif message.tool_calls is defined and message.tool_calls is not none %}
|
||||
{%- set content = message["content"] if message["content"] else "" %}
|
||||
{{- "<|im_start|>assistant\n" + content }}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- set function=tool_call.function %}
|
||||
{{- "<|action_start|><|plugin|>\n" }}
|
||||
{{- '{"name": "' + function.name + '", '}}
|
||||
{{- '"arguments": ' + function.arguments|tojson + '}' }}
|
||||
{{- "<|action_end|>" }}
|
||||
{%- endfor %}
|
||||
{{- "<|im_end|>\n" }}
|
||||
{%- elif message["role"] == "assistant" %}
|
||||
{{- "<|im_start|>assistant\n" + message["content"] + "<|im_end|>\n"}}
|
||||
{%- elif message["role"] == "tool_results" or message["role"] == "tool" or message["role"] == "function" %}
|
||||
{%- if message.content is defined and message.content.content is defined %}
|
||||
{%- set content = message.content.content %}
|
||||
{%- else %}
|
||||
{%- set content = message.content %}
|
||||
{%- endif %}
|
||||
{{- "<|im_start|>environment name=<|plugin|>\n" + content|string + "<|im_end|>\n" }}
|
||||
{%- else %}
|
||||
{{- raise_exception("Only user and assistant and tool_results and tool and function roles are supported, with the exception of an initial optional system message!") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
@ -87,6 +87,18 @@ CONFIGS: Dict[str, ServerConfig] = {
|
||||
"call the tool. Otherwise, answer the user's query directly "
|
||||
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
|
||||
"to the user's question - just respond to it normally."
|
||||
},
|
||||
"internlm": {
|
||||
"model":
|
||||
"internlm/internlm2_5-7b-chat",
|
||||
"arguments": [
|
||||
"--tool-call-parser", "internlm", "--chat-template",
|
||||
str(VLLM_PATH /
|
||||
"examples/tool_chat_template_internlm2_tool.jinja"),
|
||||
"--trust_remote_code"
|
||||
],
|
||||
"supports_parallel":
|
||||
False,
|
||||
}
|
||||
}
|
||||
|
||||
@ -109,7 +121,7 @@ WEATHER_TOOL: ChatCompletionToolParam = {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"the two-letter abbreviation for the state "
|
||||
"must the two-letter abbreviation for the state "
|
||||
"that the city is in, e.g. 'CA' which would "
|
||||
"mean 'California'"
|
||||
},
|
||||
|
@ -53,6 +53,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
|
||||
@ -526,6 +527,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
valide_tool_parses = ToolParserManager.tool_parsers.keys()
|
||||
if args.enable_auto_tool_choice \
|
||||
and args.tool_call_parser not in valide_tool_parses:
|
||||
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
|
||||
f"(chose from {{ {','.join(valide_tool_parses)} }})")
|
||||
|
||||
temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
temp_socket.bind(("", args.port))
|
||||
|
||||
|
@ -12,6 +12,7 @@ from typing import List, Optional, Sequence, Union
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -190,16 +191,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"Enable auto tool choice for supported models. Use --tool-call-parser"
|
||||
"to specify which parser to use")
|
||||
|
||||
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
type=str,
|
||||
choices=["mistral", "hermes", "llama3_json"],
|
||||
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
|
||||
"--tool-parser-plugin",
|
||||
default=None,
|
||||
help=
|
||||
"Select the tool call parser depending on the model that you're using."
|
||||
" This is used to parse the model-generated tool call into OpenAI API "
|
||||
"format. Required for --enable-auto-tool-choice.")
|
||||
|
||||
parser.add_argument(
|
||||
"--tool-parser-plugin",
|
||||
type=str,
|
||||
default="",
|
||||
help=
|
||||
"Special the tool parser plugin write to parse the model-generated tool"
|
||||
" into OpenAI API format, the name register in this plugin can be used "
|
||||
"in --tool-call-parser.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument('--max-log-len',
|
||||
|
@ -29,10 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath,
|
||||
TextTokensPrompt)
|
||||
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
|
||||
Llama3JsonToolParser,
|
||||
MistralToolParser,
|
||||
ToolParser)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
@ -82,15 +79,13 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
|
||||
if self.enable_auto_tools:
|
||||
if tool_parser == "mistral":
|
||||
self.tool_parser = MistralToolParser
|
||||
elif tool_parser == "hermes":
|
||||
self.tool_parser = Hermes2ProToolParser
|
||||
elif tool_parser == "llama3_json":
|
||||
self.tool_parser = Llama3JsonToolParser
|
||||
else:
|
||||
try:
|
||||
self.tool_parser = ToolParserManager.get_tool_parser(
|
||||
tool_parser)
|
||||
except Exception as e:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||
"--tool-call-parser")
|
||||
f"tool_parser:'{tool_parser}' which has not "
|
||||
"been registered") from e
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
@ -187,6 +182,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
if self.enable_auto_tools and self.tool_parser:
|
||||
request = self.tool_parser(tokenizer).adjust_request(
|
||||
request=request)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt_inputs = self._tokenize_prompt_input(
|
||||
request,
|
||||
@ -282,11 +281,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
previous_num_tokens = [0] * num_choices
|
||||
finish_reason_sent = [False] * num_choices
|
||||
|
||||
num_prompt_tokens = 0
|
||||
|
||||
tool_parser: Optional[ToolParser] = self.tool_parser(
|
||||
tokenizer) if self.tool_parser else None
|
||||
tool_parsers: List[Optional[ToolParser]] = [
|
||||
self.tool_parser(tokenizer) if self.tool_parser else None
|
||||
] * num_choices
|
||||
|
||||
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
||||
tool_choice_function_name = request.tool_choice.function.name
|
||||
@ -324,7 +323,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# NOTE num_choices defaults to 1 so this usually executes
|
||||
# once per request
|
||||
for i in range(num_choices):
|
||||
|
||||
tool_parser = tool_parsers[i]
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(
|
||||
@ -399,6 +398,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
tool_parser = tool_parsers[i]
|
||||
|
||||
if finish_reason_sent[i]:
|
||||
continue
|
||||
@ -446,7 +446,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=output.token_ids))
|
||||
delta_token_ids=output.token_ids,
|
||||
request=request))
|
||||
|
||||
# update the previous values for the next iteration
|
||||
previous_texts[i] = current_text
|
||||
@ -685,7 +686,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
and self.tool_parser:
|
||||
|
||||
tool_parser = self.tool_parser(tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(output.text)
|
||||
tool_call_info = tool_parser.extract_tool_calls(
|
||||
output.text, request=request)
|
||||
tools_called = tool_call_info.tools_called
|
||||
if tool_call_info.tools_called:
|
||||
message = ChatMessage(role=role,
|
||||
|
@ -1,9 +1,10 @@
|
||||
from .abstract_tool_parser import ToolParser
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
from .internlm2_tool_parser import Internlm2ToolParser
|
||||
from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
|
||||
__all__ = [
|
||||
"ToolParser", "Hermes2ProToolParser", "MistralToolParser",
|
||||
"Llama3JsonToolParser"
|
||||
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
|
||||
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
|
||||
]
|
||||
|
@ -1,9 +1,14 @@
|
||||
from typing import Dict, List, Sequence, Union
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaMessage,
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -24,8 +29,16 @@ class ToolParser:
|
||||
|
||||
self.model_tokenizer = tokenizer
|
||||
|
||||
def extract_tool_calls(self,
|
||||
model_output: str) -> ExtractedToolCallInformation:
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
"""
|
||||
Static method that used to adjust the request parameters.
|
||||
"""
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Static method that should be implemented for extracting tool calls from
|
||||
a complete model-generated string.
|
||||
@ -44,6 +57,7 @@ class ToolParser:
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Instance method that should be implemented for extracting tool calls
|
||||
@ -55,3 +69,86 @@ class ToolParser:
|
||||
raise NotImplementedError(
|
||||
"AbstractToolParser.extract_tool_calls_streaming has not been "
|
||||
"implemented!")
|
||||
|
||||
|
||||
class ToolParserManager:
|
||||
tool_parsers: Dict[str, Type] = {}
|
||||
|
||||
@classmethod
|
||||
def get_tool_parser(cls, name) -> Type:
|
||||
"""
|
||||
Get tool parser by name which is registered by `register_module`.
|
||||
|
||||
Raise a KeyError exception if the name is not registered.
|
||||
"""
|
||||
if name in cls.tool_parsers:
|
||||
return cls.tool_parsers[name]
|
||||
|
||||
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
|
||||
|
||||
@classmethod
|
||||
def _register_module(cls,
|
||||
module: Type,
|
||||
module_name: Optional[Union[str, List[str]]] = None,
|
||||
force: bool = True) -> None:
|
||||
if not issubclass(module, ToolParser):
|
||||
raise TypeError(
|
||||
f'module must be subclass of ToolParser, but got {type(module)}'
|
||||
)
|
||||
if module_name is None:
|
||||
module_name = module.__name__
|
||||
if isinstance(module_name, str):
|
||||
module_name = [module_name]
|
||||
for name in module_name:
|
||||
if not force and name in cls.tool_parsers:
|
||||
existed_module = cls.tool_parsers[name]
|
||||
raise KeyError(f'{name} is already registered '
|
||||
f'at {existed_module.__module__}')
|
||||
cls.tool_parsers[name] = module
|
||||
|
||||
@classmethod
|
||||
def register_module(
|
||||
cls,
|
||||
name: Optional[Union[str, List[str]]] = None,
|
||||
force: bool = True,
|
||||
module: Union[Type, None] = None) -> Union[type, Callable]:
|
||||
"""
|
||||
Register module with the given name or name list. it can be used as a
|
||||
decoder(with module as None) or normal function(with module as not
|
||||
None).
|
||||
"""
|
||||
if not isinstance(force, bool):
|
||||
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
||||
|
||||
# raise the error ahead of time
|
||||
if not (name is None or isinstance(name, str)
|
||||
or is_list_of(name, str)):
|
||||
raise TypeError(
|
||||
'name must be None, an instance of str, or a sequence of str, '
|
||||
f'but got {type(name)}')
|
||||
|
||||
# use it as a normal method: x.register_module(module=SomeClass)
|
||||
if module is not None:
|
||||
cls._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
# use it as a decorator: @x.register_module()
|
||||
def _register(module):
|
||||
cls._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
return _register
|
||||
|
||||
@classmethod
|
||||
def import_tool_parser(cls, plugin_path: str) -> None:
|
||||
"""
|
||||
Import a user defined tool parser by the path of the tool parser define
|
||||
file.
|
||||
"""
|
||||
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
|
||||
spec = importlib.util.spec_from_file_location(module_name, plugin_path)
|
||||
if spec is None or spec.loader is None:
|
||||
logger.error("load %s from %s failed.", module_name, plugin_path)
|
||||
return
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
@ -5,12 +5,13 @@ from typing import Dict, List, Sequence, Union
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser)
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
@ -20,6 +21,7 @@ from vllm.utils import random_uuid
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("hermes")
|
||||
class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
@ -57,8 +59,11 @@ class Hermes2ProToolParser(ToolParser):
|
||||
"Hermes 2 Pro Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def extract_tool_calls(self,
|
||||
model_output: str) -> ExtractedToolCallInformation:
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_call_start_token not in model_output:
|
||||
@ -114,6 +119,7 @@ class Hermes2ProToolParser(ToolParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
|
208
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
208
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
@ -0,0 +1,208 @@
|
||||
import json
|
||||
from typing import Dict, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module(["internlm"])
|
||||
class Internlm2ToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.position = 0
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != 'none':
|
||||
# do not skip special tokens because internlm use the special
|
||||
# tokens to indicated the start and end of the tool calls
|
||||
# information.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def get_argments(self, obj):
|
||||
if "parameters" in obj:
|
||||
return obj.get("parameters")
|
||||
elif "arguments" in obj:
|
||||
return obj.get("arguments")
|
||||
return None
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
if '<|action_start|>' not in current_text:
|
||||
self.position = len(current_text)
|
||||
return DeltaMessage(content=delta_text)
|
||||
# if the tool call is sended, return a empty delta message
|
||||
# to make sure the finish_reason will be send correctly.
|
||||
if self.current_tool_id > 0:
|
||||
return DeltaMessage(content='')
|
||||
|
||||
last_pos = self.position
|
||||
if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
|
||||
return None
|
||||
|
||||
new_delta = current_text[last_pos:]
|
||||
text, action = new_delta.split('<|action_start|><|plugin|>')
|
||||
|
||||
if len(text) > 0:
|
||||
self.position = self.position + len(text)
|
||||
return DeltaMessage(content=text)
|
||||
|
||||
action = action.strip()
|
||||
action = action.split('<|action_end|>'.strip())[0]
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
|
||||
try:
|
||||
parsable_arr = action
|
||||
|
||||
# tool calls are generated in an object in inernlm2
|
||||
# it's not support parallel tool calls
|
||||
try:
|
||||
tool_call_arr: Dict = partial_json_parser.loads(
|
||||
parsable_arr, flags)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
if not self.current_tool_name_sent:
|
||||
function_name = tool_call_arr.get("name")
|
||||
if function_name:
|
||||
self.current_tool_id = self.current_tool_id + 1
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
self.streamed_args_for_tool.append("")
|
||||
else:
|
||||
delta = None
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
prev_arguments = self.get_argments(
|
||||
self.prev_tool_call_arr[self.current_tool_id])
|
||||
cur_arguments = self.get_argments(tool_call_arr)
|
||||
|
||||
# not arguments generated
|
||||
if not cur_arguments and not prev_arguments:
|
||||
delta = None
|
||||
# will never happen
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error(
|
||||
"INVARIANT - impossible to have arguments reset "
|
||||
"mid-arguments")
|
||||
delta = None
|
||||
# first time to get parameters
|
||||
elif cur_arguments and not prev_arguments:
|
||||
cur_arguments_json = json.dumps(cur_arguments)
|
||||
|
||||
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
||||
index(delta_text) +
|
||||
len(delta_text)]
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += arguments_delta
|
||||
# both prev and cur parameters, send the increase parameters
|
||||
elif cur_arguments and prev_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
prev_args_json = json.dumps(prev_arguments)
|
||||
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
# check to see if the name is defined and has been sent. if so,
|
||||
# stream the name - otherwise keep waiting
|
||||
# finish by setting old and returning None as base case
|
||||
tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
|
||||
self.prev_tool_call_arr = [tool_call_arr]
|
||||
return delta
|
||||
except Exception as e:
|
||||
logger.error("Error trying to handle streaming tool call: %s", e)
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
text = model_output
|
||||
tools = request.tools
|
||||
if '<|action_start|><|plugin|>' in text:
|
||||
text, action = text.split('<|action_start|><|plugin|>')
|
||||
action = action.split('<|action_end|>'.strip())[0]
|
||||
action = action[action.find('{'):]
|
||||
action_dict = json.loads(action)
|
||||
name, parameters = action_dict['name'], json.dumps(
|
||||
action_dict.get('parameters', action_dict.get('arguments',
|
||||
{})))
|
||||
|
||||
if not tools or name not in [t.function.name for t in tools]:
|
||||
ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=text)
|
||||
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
function=FunctionCall(name=name, arguments=parameters))
|
||||
]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=text if len(text) > 0 else None)
|
||||
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=text)
|
@ -7,12 +7,13 @@ import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser)
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
@ -41,6 +42,7 @@ def is_complete_json(input_str):
|
||||
return False
|
||||
|
||||
|
||||
@ToolParserManager.register_module("llama3_json")
|
||||
class Llama3JsonToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Llama 3.1 models intended for use with the
|
||||
@ -64,8 +66,9 @@ class Llama3JsonToolParser(ToolParser):
|
||||
add_special_tokens=False)[0]
|
||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
||||
|
||||
def extract_tool_calls(self,
|
||||
model_output: str) -> ExtractedToolCallInformation:
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
@ -125,6 +128,7 @@ class Llama3JsonToolParser(ToolParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if not (current_text.startswith(self.bot_token)
|
||||
|
@ -8,12 +8,13 @@ import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser)
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
@ -36,6 +37,7 @@ class MistralToolCall(ToolCall):
|
||||
return "".join(choices(ALPHANUMERIC, k=9))
|
||||
|
||||
|
||||
@ToolParserManager.register_module("mistral")
|
||||
class MistralToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
|
||||
@ -47,9 +49,7 @@ class MistralToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
self.model_tokenizer = self.model_tokenizer.tokenizer
|
||||
else:
|
||||
if not isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
logger.info("Non-Mistral tokenizer detected when using a Mistral "
|
||||
"model...")
|
||||
|
||||
@ -61,11 +61,14 @@ class MistralToolParser(ToolParser):
|
||||
self.streamed_args_for_tool: List[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = "[TOOL_CALLS]"
|
||||
self.bot_token_id = self.model_tokenizer.vocab[self.bot_token]
|
||||
self.bot_token_id = self.model_tokenizer.get_vocab()[self.bot_token]
|
||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
||||
|
||||
def extract_tool_calls(self,
|
||||
model_output: str) -> ExtractedToolCallInformation:
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response. Requires
|
||||
find-and-replacing single quotes with double quotes for JSON parsing,
|
||||
@ -119,6 +122,7 @@ class MistralToolParser(ToolParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
# if the tool call token is not in the tokens generated so far, append
|
||||
|
Loading…
x
Reference in New Issue
Block a user