[Frontend][Feature] support tool calling for internlm/internlm2_5-7b-chat model (#8405)

This commit is contained in:
代君 2024-10-04 10:36:39 +08:00 committed by GitHub
parent 2838d6b38e
commit 3dbb215b38
13 changed files with 533 additions and 46 deletions

View File

@ -12,4 +12,5 @@ torch
py-cpuinfo
transformers
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args

View File

@ -157,8 +157,9 @@ vLLM will use guided decoding to ensure the response matches the tool parameter
To enable this feature, you should set the following flags:
* `--enable-auto-tool-choice` -- **mandatory** for auto tool choice. Tells vLLM that you want to enable the model to generate its own tool calls when it
deems appropriate.
* `--tool-call-parser` -- select the tool parser to use -- currently `hermes`, `mistral`, `llama3_json`, or `internlm`. Additional tool parsers
will continue to be added in the future, and you can also register your own tool parsers via `--tool-parser-plugin`.
* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user-defined tool parsers into vLLM; the tool parser names registered by the plugin can then be passed to `--tool-call-parser`.
* `--chat-template` -- **optional** for auto tool choice. The path to the chat template which handles `tool`-role messages and `assistant`-role messages
that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
@ -218,4 +219,73 @@ it works better with vLLM.
Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`
#### Internlm Models
Supported models:
* `internlm/internlm2_5-7b-chat` (confirmed)
* Additional internlm2.5 function-calling models are compatible as well
Known issues:
* Although this implementation also supports Internlm2, the tool call results are not stable when tested with the `internlm/internlm2-chat-7b` model.
Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja`
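Once the server is up with those flags, a request that attaches tools should come back with parsed `tool_calls`. Below is a minimal client-side sketch; the `get_weather` tool definition, the port, and the API key are illustrative assumptions, not part of this PR:
```python
# a minimal sketch, assuming a server started with:
#   vllm serve internlm/internlm2_5-7b-chat --trust-remote-code \
#     --enable-auto-tool-choice --tool-call-parser internlm \
#     --chat-template examples/tool_chat_template_internlm2_tool.jinja
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="empty")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical example tool
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "The city name"}
            },
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="internlm/internlm2_5-7b-chat",
    messages=[{"role": "user", "content": "What's the weather in Dallas?"}],
    tools=tools,
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)
```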
### How to write a tool parser plugin
A tool parser plugin is a Python file containing one or more `ToolParser` implementations. You can write a `ToolParser` similar to the `Hermes2ProToolParser` in `vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py`.
Here is a summary of a plugin file:
```python
# import the required packages
from typing import Sequence, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage,
                                              ExtractedToolCallInformation)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.transformers_utils.tokenizer import AnyTokenizer


# define a tool parser and register it to vllm
# the name list in register_module can be used
# in --tool-call-parser. you can define as many
# tool parsers as you want here.
@ToolParserManager.register_module(["example"])
class ExampleToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    # adjust the request, e.g. set skip_special_tokens
    # to False for tool call output
    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        return request

    # implement the tool call parsing for streaming calls
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        # a real parser would compute and return a tool-call delta here
        return DeltaMessage(content=delta_text)

    # implement the tool call parsing for non-streaming calls
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)
```
Then you can use this plugin on the command line like this:
```
--enable-auto-tool-choice \
--tool-parser-plugin <absolute path of the plugin file> \
--tool-call-parser example \
--chat-template <your chat template> \
```
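The plugin file is imported when the server starts. A quick way to check that a plugin registers its parser name without launching a server is the following sketch; the plugin path is a hypothetical placeholder:
```python
# a sketch: verify plugin registration, assuming the example plugin above
# was saved at the hypothetical path /path/to/example_parser.py
from vllm.entrypoints.openai.tool_parsers import ToolParserManager

ToolParserManager.import_tool_parser("/path/to/example_parser.py")
print(ToolParserManager.tool_parsers.keys())  # should now include 'example'
```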

View File

@ -0,0 +1,60 @@
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{{- bos_token }}
{%- if system_message is defined %}
{{- "<|im_start|>system\n" + system_message + "<|im_end|>\n" }}
{%- endif %}
{%- if tools is not none %}
{{- "<|im_start|>system name=<|plugin|>\n[" }}
{%- for tool in tools %}
{{- tool.function|tojson }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "<|im_end|>\n" }}
{%- endif %}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{{- "<|im_start|>user\n" + message["content"] + "<|im_end|>\n"}}
{%- elif message.tool_calls is defined and message.tool_calls is not none %}
{%- set content = message["content"] if message["content"] else "" %}
{{- "<|im_start|>assistant\n" + content }}
{%- for tool_call in message.tool_calls %}
{%- set function=tool_call.function %}
{{- "<|action_start|><|plugin|>\n" }}
{{- '{"name": "' + function.name + '", '}}
{{- '"arguments": ' + function.arguments|tojson + '}' }}
{{- "<|action_end|>" }}
{%- endfor %}
{{- "<|im_end|>\n" }}
{%- elif message["role"] == "assistant" %}
{{- "<|im_start|>assistant\n" + message["content"] + "<|im_end|>\n"}}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" or message["role"] == "function" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- "<|im_start|>environment name=<|plugin|>\n" + content|string + "<|im_end|>\n" }}
{%- else %}
{{- raise_exception("Only user and assistant and tool_results and tool and function roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
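To inspect the prompt this template produces, it can be rendered offline. A sketch assuming a `transformers` version whose `apply_chat_template` accepts `tools=`; the tool and messages are illustrative:
```python
# a sketch: render the template above offline to inspect the prompt
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2_5-7b-chat",
                                          trust_remote_code=True)
with open("examples/tool_chat_template_internlm2_tool.jinja") as f:
    template = f.read()

tools = [{"type": "function",
          "function": {"name": "get_weather",  # hypothetical tool
                       "description": "Get the weather for a city",
                       "parameters": {"type": "object",
                                      "properties": {
                                          "city": {"type": "string"}}}}}]
messages = [{"role": "user", "content": "What's the weather in Dallas?"}]

prompt = tokenizer.apply_chat_template(messages, tools=tools,
                                       chat_template=template,
                                       tokenize=False,
                                       add_generation_prompt=True)
print(prompt)  # shows the <|im_start|>system name=<|plugin|> tool block
```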

View File

@ -87,6 +87,18 @@ CONFIGS: Dict[str, ServerConfig] = {
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
},
"internlm": {
"model":
"internlm/internlm2_5-7b-chat",
"arguments": [
"--tool-call-parser", "internlm", "--chat-template",
str(VLLM_PATH /
"examples/tool_chat_template_internlm2_tool.jinja"),
"--trust_remote_code"
],
"supports_parallel":
False,
}
}
@ -109,7 +121,7 @@ WEATHER_TOOL: ChatCompletionToolParam = {
"type":
"string",
"description":
"the two-letter abbreviation for the state "
"must the two-letter abbreviation for the state "
"that the city is in, e.g. 'CA' which would "
"mean 'California'"
},

View File

@ -53,6 +53,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
@ -526,6 +527,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valid_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(choose from {{ {','.join(valid_tool_parses)} }})")
temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
temp_socket.bind(("", args.port))

View File

@ -12,6 +12,7 @@ from typing import List, Optional, Sequence, Union
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser
@ -190,16 +191,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"Enable auto tool choice for supported models. Use --tool-call-parser"
"to specify which parser to use")
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["mistral", "hermes", "llama3_json"],
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
"--tool-parser-plugin",
default=None,
help=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.")
parser.add_argument(
"--tool-parser-plugin",
type=str,
default="",
help=
"Special the tool parser plugin write to parse the model-generated tool"
" into OpenAI API format, the name register in this plugin can be used "
"in --tool-call-parser.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',

View File

@ -29,10 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
OpenAIServing,
PromptAdapterPath,
TextTokensPrompt)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
@ -82,15 +79,13 @@ class OpenAIServingChat(OpenAIServing):
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools:
if tool_parser == "mistral":
self.tool_parser = MistralToolParser
elif tool_parser == "hermes":
self.tool_parser = Hermes2ProToolParser
elif tool_parser == "llama3_json":
self.tool_parser = Llama3JsonToolParser
else:
try:
self.tool_parser = ToolParserManager.get_tool_parser(
tool_parser)
except Exception as e:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
f"tool_parser:'{tool_parser}' which has not "
"been registered") from e
async def create_chat_completion(
self,
@ -187,6 +182,10 @@ class OpenAIServingChat(OpenAIServing):
raw_request.state.request_metadata = request_metadata
try:
if self.enable_auto_tools and self.tool_parser:
request = self.tool_parser(tokenizer).adjust_request(
request=request)
if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
@ -282,11 +281,11 @@ class OpenAIServingChat(OpenAIServing):
num_choices = 1 if request.n is None else request.n
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0
tool_parsers: List[Optional[ToolParser]] = [
self.tool_parser(tokenizer) if self.tool_parser else None
] * num_choices
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
@ -324,7 +323,7 @@ class OpenAIServingChat(OpenAIServing):
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for i in range(num_choices):
tool_parser = tool_parsers[i]
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
@ -399,6 +398,7 @@ class OpenAIServingChat(OpenAIServing):
for output in res.outputs:
i = output.index
tool_parser = tool_parsers[i]
if finish_reason_sent[i]:
continue
@ -446,7 +446,8 @@ class OpenAIServingChat(OpenAIServing):
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids,
request=request))
# update the previous values for the next iteration
previous_texts[i] = current_text
@ -685,7 +686,8 @@ class OpenAIServingChat(OpenAIServing):
and self.tool_parser:
tool_parser = self.tool_parser(tokenizer)
tool_call_info = tool_parser.extract_tool_calls(
output.text, request=request)
tools_called = tool_call_info.tools_called
if tool_call_info.tools_called:
message = ChatMessage(role=role,

View File

@ -1,9 +1,10 @@
from .abstract_tool_parser import ToolParser, ToolParserManager
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser
__all__ = [
"ToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Llama3JsonToolParser"
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
]

View File

@ -1,9 +1,14 @@
import importlib
import importlib.util
import os
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import is_list_of
logger = init_logger(__name__)
@ -24,8 +29,16 @@ class ToolParser:
self.model_tokenizer = tokenizer
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
"""
Instance method that is used to adjust the request parameters.
"""
return request
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Instance method that should be implemented for extracting tool calls from
a complete model-generated string.
@ -44,6 +57,7 @@ class ToolParser:
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
"""
Instance method that should be implemented for extracting tool calls
@ -55,3 +69,86 @@ class ToolParser:
raise NotImplementedError(
"AbstractToolParser.extract_tool_calls_streaming has not been "
"implemented!")
class ToolParserManager:
tool_parsers: Dict[str, Type] = {}
@classmethod
def get_tool_parser(cls, name) -> Type:
"""
Get tool parser by name which is registered by `register_module`.
Raise a KeyError exception if the name is not registered.
"""
if name in cls.tool_parsers:
return cls.tool_parsers[name]
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
@classmethod
def _register_module(cls,
module: Type,
module_name: Optional[Union[str, List[str]]] = None,
force: bool = True) -> None:
if not issubclass(module, ToolParser):
raise TypeError(
f'module must be a subclass of ToolParser, but got {type(module)}'
)
if module_name is None:
module_name = module.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in cls.tool_parsers:
existed_module = cls.tool_parsers[name]
raise KeyError(f'{name} is already registered '
f'at {existed_module.__module__}')
cls.tool_parsers[name] = module
@classmethod
def register_module(
cls,
name: Optional[Union[str, List[str]]] = None,
force: bool = True,
module: Union[Type, None] = None) -> Union[type, Callable]:
"""
Register module with the given name or name list. It can be used as a
decorator (with module as None) or as a normal function (with module
not None).
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
# raise the error ahead of time
if not (name is None or isinstance(name, str)
or is_list_of(name, str)):
raise TypeError(
'name must be None, an instance of str, or a sequence of str, '
f'but got {type(name)}')
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
cls._register_module(module=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(module):
cls._register_module(module=module, module_name=name, force=force)
return module
return _register
@classmethod
def import_tool_parser(cls, plugin_path: str) -> None:
"""
Import a user-defined tool parser from the path of the file that
defines it.
"""
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
spec = importlib.util.spec_from_file_location(module_name, plugin_path)
if spec is None or spec.loader is None:
logger.error("load %s from %s failed.", module_name, plugin_path)
return
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
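For reference, `register_module` supports both styles mentioned in its docstring. A brief sketch, where `MyToolParser` is a hypothetical subclass:
```python
# a sketch of the two registration styles supported by register_module;
# MyToolParser is a hypothetical ToolParser subclass
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)


@ToolParserManager.register_module(["my_parser"])  # used as a decorator
class MyToolParser(ToolParser):
    pass


# used as a normal function call
ToolParserManager.register_module(name="my_alias", module=MyToolParser)

assert ToolParserManager.get_tool_parser("my_alias") is MyToolParser
```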

View File

@ -5,12 +5,13 @@ from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
@ -20,6 +21,7 @@ from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module("hermes")
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
@ -57,8 +59,11 @@ class Hermes2ProToolParser(ToolParser):
"Hermes 2 Pro Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_call_start_token not in model_output:
@ -114,6 +119,7 @@ class Hermes2ProToolParser(ToolParser):
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
logger.debug("delta_text: %s", delta_text)

View File

@ -0,0 +1,208 @@
import json
from typing import Dict, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module(["internlm"])
class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
self.position = 0
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if request.tools and request.tool_choice != 'none':
# do not skip special tokens because internlm uses the special
# tokens to indicate the start and end of the tool call
# information
request.skip_special_tokens = False
return request
def get_arguments(self, obj):
if "parameters" in obj:
return obj.get("parameters")
elif "arguments" in obj:
return obj.get("arguments")
return None
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if '<|action_start|>' not in current_text:
self.position = len(current_text)
return DeltaMessage(content=delta_text)
# if the tool call has already been sent, return an empty delta
# message to make sure the finish_reason will be sent correctly.
if self.current_tool_id > 0:
return DeltaMessage(content='')
last_pos = self.position
if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
return None
new_delta = current_text[last_pos:]
text, action = new_delta.split('<|action_start|><|plugin|>')
if len(text) > 0:
self.position = self.position + len(text)
return DeltaMessage(content=text)
action = action.strip()
action = action.split('<|action_end|>')[0]
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
parsable_arr = action
# tool calls are generated in an object in internlm2;
# it does not support parallel tool calls
try:
tool_call_arr: Dict = partial_json_parser.loads(
parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = tool_call_arr.get("name")
if function_name:
self.current_tool_id = self.current_tool_id + 1
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
self.streamed_args_for_tool.append("")
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.get_arguments(
self.prev_tool_call_arr[self.current_tool_id])
cur_arguments = self.get_arguments(tool_call_arr)
# no arguments generated yet
if not cur_arguments and not prev_arguments:
delta = None
# will never happen
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset "
"mid-arguments")
delta = None
# first time receiving the arguments
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments)
arguments_delta = cur_arguments_json[:cur_arguments_json.
index(delta_text) +
len(delta_text)]
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += arguments_delta
# both prev and cur arguments exist, send the incremental diff
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
# save the current state for the next iteration and return the
# computed delta (None when there is nothing new to send)
tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
self.prev_tool_call_arr = [tool_call_arr]
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
text = model_output
tools = request.tools
if '<|action_start|><|plugin|>' in text:
text, action = text.split('<|action_start|><|plugin|>')
action = action.split('<|action_end|>')[0]
action = action[action.find('{'):]
action_dict = json.loads(action)
name, parameters = action_dict['name'], json.dumps(
action_dict.get('parameters', action_dict.get('arguments',
{})))
if not tools or name not in [t.function.name for t in tools]:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=text)
tool_calls = [
ToolCall(
function=FunctionCall(name=name, arguments=parameters))
]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=text if len(text) > 0 else None)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=text)
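For reference, internlm2 wraps a tool call between `<|action_start|><|plugin|>` and `<|action_end|>`, and the non-streaming path above splits on those markers. A sketch that exercises it; the tokenizer load and the `get_weather` tool are illustrative assumptions:
```python
# a sketch: exercise the non-streaming parser on a hand-written output
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers import Internlm2ToolParser

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2_5-7b-chat",
                                          trust_remote_code=True)
parser = Internlm2ToolParser(tokenizer)

model_output = ('I will look that up.<|action_start|><|plugin|>\n'
                '{"name": "get_weather", "parameters": {"city": "Dallas"}}'
                '<|action_end|>')
request = ChatCompletionRequest(
    model="internlm/internlm2_5-7b-chat",
    messages=[{"role": "user", "content": "What's the weather in Dallas?"}],
    tools=[{"type": "function",
            "function": {"name": "get_weather",
                         "parameters": {"type": "object"}}}],
)

info = parser.extract_tool_calls(model_output, request=request)
# info.tools_called -> True
# info.tool_calls[0].function.name -> "get_weather"
# info.tool_calls[0].function.arguments -> '{"city": "Dallas"}'
```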

View File

@ -7,12 +7,13 @@ import partial_json_parser
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix
from vllm.logger import init_logger
from vllm.utils import random_uuid
@ -41,6 +42,7 @@ def is_complete_json(input_str):
return False
@ToolParserManager.register_module("llama3_json")
class Llama3JsonToolParser(ToolParser):
"""
Tool call parser for Llama 3.1 models intended for use with the
@ -64,8 +66,9 @@ class Llama3JsonToolParser(ToolParser):
add_special_tokens=False)[0]
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
@ -125,6 +128,7 @@ class Llama3JsonToolParser(ToolParser):
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if not (current_text.startswith(self.bot_token)

View File

@ -8,12 +8,13 @@ import partial_json_parser
from partial_json_parser.core.options import Allow
from pydantic import Field
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
@ -36,6 +37,7 @@ class MistralToolCall(ToolCall):
return "".join(choices(ALPHANUMERIC, k=9))
@ToolParserManager.register_module("mistral")
class MistralToolParser(ToolParser):
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
@ -47,9 +49,7 @@ class MistralToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer):
logger.info("Non-Mistral tokenizer detected when using a Mistral "
"model...")
@ -61,11 +61,14 @@ class MistralToolParser(ToolParser):
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.model_tokenizer.get_vocab()[self.bot_token]
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing,
@ -119,6 +122,7 @@ class MistralToolParser(ToolParser):
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
# if the tool call token is not in the tokens generated so far, append