[Bugfix] Multiple fixes to tool streaming with hermes and mistral (#10979)

Signed-off-by: cedonley <clayton@donley.io>
This commit is contained in:
Clayton 2024-12-11 17:10:12 -08:00 committed by GitHub
parent 4e11683368
commit 7439a8b5fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 69 additions and 21 deletions

View File

@ -496,21 +496,33 @@ class OpenAIServingChat(OpenAIServing):
if self._should_check_for_unstreamed_tool_arg_tokens( if self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output) and tool_parser: delta_message, output) and tool_parser:
latest_delta_len = 0
if ((isinstance(
delta_message.tool_calls[0].function,
DeltaFunctionCall)) and isinstance(
delta_message.tool_calls[0].function.
arguments, str)):
latest_delta_len = len(
delta_message.tool_calls[0].function.
arguments)
# get the expected call based on partial JSON # get the expected call based on partial JSON
# parsing which "autocompletes" the JSON # parsing which "autocompletes" the JSON
expected_call = json.dumps( expected_call = json.dumps(
tool_parser.prev_tool_call_arr[index].get( tool_parser.prev_tool_call_arr[index].get(
"arguments", {})) "arguments", {}),
ensure_ascii=False)
# get what we've streamed so far for arguments # get what we've streamed so far for arguments
# for the current tool # for the current tool
actual_call = tool_parser.streamed_args_for_tool[ actual_call = tool_parser.streamed_args_for_tool[
index] index]
if (latest_delta_len > 0):
actual_call = actual_call[:-latest_delta_len]
# check to see if there's anything left to stream # check to see if there's anything left to stream
remaining_call = expected_call.replace( remaining_call = expected_call.replace(
actual_call, "", 1) actual_call, "", 1)
# set that as a delta message # set that as a delta message
delta_message = DeltaMessage(tool_calls=[ delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(index=index, DeltaToolCall(index=index,

View File

@ -91,7 +91,8 @@ class Hermes2ProToolParser(ToolParser):
function=FunctionCall( function=FunctionCall(
name=function_call["name"], name=function_call["name"],
# function call args are JSON but as a string # function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"]))) arguments=json.dumps(function_call["arguments"],
ensure_ascii=False)))
for function_call in raw_function_calls for function_call in raw_function_calls
] ]
@ -139,14 +140,27 @@ class Hermes2ProToolParser(ToolParser):
self.tool_call_start_token_id) self.tool_call_start_token_id)
cur_tool_end_count = current_token_ids.count( cur_tool_end_count = current_token_ids.count(
self.tool_call_end_token_id) self.tool_call_end_token_id)
tool_call_portion = None
text_portion = None
# case: if we're generating text, OR rounding out a tool call # case: if we're generating text, OR rounding out a tool call
if (cur_tool_start_count == cur_tool_end_count if (cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count): and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text):
logger.debug("Generating text content! skipping tool parsing.") logger.debug("Generating text content! skipping tool parsing.")
if delta_text != self.tool_call_end_token:
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
if self.tool_call_end_token in delta_text:
logger.debug("tool_call_end_token in delta_text")
full_text = current_text + delta_text
tool_call_portion = full_text.split(
self.tool_call_start_token)[-1].split(
self.tool_call_end_token)[0].rstrip()
delta_text = delta_text.split(
self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(
self.tool_call_end_token)[-1].lstrip()
# case: if tool open & close tag counts don't match, we're doing # case: if tool open & close tag counts don't match, we're doing
# imaginary "else" block here # imaginary "else" block here
# something with tools with this diff. # something with tools with this diff.
@ -184,15 +198,21 @@ class Hermes2ProToolParser(ToolParser):
# case -- the current tool call is being closed. # case -- the current tool call is being closed.
elif (cur_tool_start_count == cur_tool_end_count elif (cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count > prev_tool_end_count): and cur_tool_end_count >= prev_tool_end_count):
if (self.prev_tool_call_arr is None
or len(self.prev_tool_call_arr) == 0):
logger.debug(
"attempting to close tool call, but no tool call")
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get( diff = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments") "arguments")
if diff: if diff:
diff = diff.encode('utf-8').decode( diff = diff.encode('utf-8').decode(
'unicode_escape') if diff is str else diff 'unicode_escape') if diff is str else diff
diff = json.dumps( if ('"}' not in delta_text):
diff, ensure_ascii=False return None
)[len(self.streamed_args_for_tool[self.current_tool_id]):] end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug( logger.debug(
"Finishing tool and found diff that had not " "Finishing tool and found diff that had not "
"been streamed yet: %s", diff) "been streamed yet: %s", diff)
@ -221,10 +241,15 @@ class Hermes2ProToolParser(ToolParser):
except partial_json_parser.core.exceptions.MalformedJSON: except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet') logger.debug('not enough tokens to parse into JSON yet')
return None return None
except json.decoder.JSONDecodeError:
logger.debug("unable to parse JSON")
return None
# case - we haven't sent the tool name yet. If it's available, send # case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available. # it. otherwise, wait until it's available.
if not self.current_tool_name_sent: if not self.current_tool_name_sent:
if (current_tool_call is None):
return None
function_name: Union[str, None] = current_tool_call.get("name") function_name: Union[str, None] = current_tool_call.get("name")
if function_name: if function_name:
self.current_tool_name_sent = True self.current_tool_name_sent = True
@ -284,13 +309,17 @@ class Hermes2ProToolParser(ToolParser):
# autocompleting the JSON # autocompleting the JSON
elif cur_arguments and not prev_arguments: elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments) cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
logger.debug("finding %s in %s", delta_text, logger.debug("finding %s in %s", delta_text,
cur_arguments_json) cur_arguments_json)
# get the location where previous args differ from current # get the location where previous args differ from current
args_delta_start_loc = cur_arguments_json.index(delta_text) \ if (delta_text not in cur_arguments_json[:-2]):
+ len(delta_text) return None
args_delta_start_loc = cur_arguments_json[:-2]. \
rindex(delta_text) + \
len(delta_text)
# use that to find the actual delta # use that to find the actual delta
arguments_delta = cur_arguments_json[:args_delta_start_loc] arguments_delta = cur_arguments_json[:args_delta_start_loc]

View File

@ -19,7 +19,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff) extract_intermediate_diff)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@ -109,7 +108,8 @@ class MistralToolParser(ToolParser):
function=FunctionCall( function=FunctionCall(
name=raw_function_call["name"], name=raw_function_call["name"],
# function call args are JSON but as a string # function call args are JSON but as a string
arguments=json.dumps(raw_function_call["arguments"]))) arguments=json.dumps(raw_function_call["arguments"],
ensure_ascii=False)))
for raw_function_call in function_call_arr for raw_function_call in function_call_arr
] ]
@ -199,7 +199,7 @@ class MistralToolParser(ToolParser):
diff: Union[str, None] = current_tool_call.get("arguments") diff: Union[str, None] = current_tool_call.get("arguments")
if diff: if diff:
diff = json.dumps(diff).replace( diff = json.dumps(diff, ensure_ascii=False).replace(
self.streamed_args_for_tool[self.current_tool_id], self.streamed_args_for_tool[self.current_tool_id],
"") "")
delta = DeltaMessage(tool_calls=[ delta = DeltaMessage(tool_calls=[
@ -232,7 +232,7 @@ class MistralToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[ delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id, DeltaToolCall(index=self.current_tool_id,
type="function", type="function",
id=f"chatcmpl-tool-{random_uuid()}", id=MistralToolCall.generate_random_id(),
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=function_name).model_dump( name=function_name).model_dump(
exclude_none=True)) exclude_none=True))
@ -250,6 +250,8 @@ class MistralToolParser(ToolParser):
cur_arguments = current_tool_call.get("arguments") cur_arguments = current_tool_call.get("arguments")
new_text = delta_text.replace("\'", "\"") new_text = delta_text.replace("\'", "\"")
if ('"}' in new_text):
new_text = new_text[:new_text.rindex('"}')]
if not cur_arguments and not prev_arguments: if not cur_arguments and not prev_arguments:
@ -260,12 +262,15 @@ class MistralToolParser(ToolParser):
"mid-arguments") "mid-arguments")
delta = None delta = None
elif cur_arguments and not prev_arguments: elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments) cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)[:-2]
logger.debug("finding %s in %s", new_text, logger.debug("finding %s in %s", new_text,
cur_arguments_json) cur_arguments_json)
if (new_text not in cur_arguments_json):
return None
arguments_delta = cur_arguments_json[:cur_arguments_json. arguments_delta = cur_arguments_json[:cur_arguments_json.
index(new_text) + rindex(new_text) +
len(new_text)] len(new_text)]
logger.debug("First tokens in arguments received: %s", logger.debug("First tokens in arguments received: %s",
arguments_delta) arguments_delta)
@ -279,8 +284,10 @@ class MistralToolParser(ToolParser):
self.current_tool_id] += arguments_delta self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments: elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments) cur_args_json = json.dumps(cur_arguments,
prev_args_json = json.dumps(prev_arguments) ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
logger.debug("Searching for diff between \n%s\n%s", logger.debug("Searching for diff between \n%s\n%s",
cur_args_json, prev_args_json) cur_args_json, prev_args_json)