[Bugfix] Multiple fixes to tool streaming with hermes and mistral (#10979)
Signed-off-by: cedonley <clayton@donley.io>
This commit is contained in:
parent
4e11683368
commit
7439a8b5fc
@ -496,21 +496,33 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
|
|
||||||
if self._should_check_for_unstreamed_tool_arg_tokens(
|
if self._should_check_for_unstreamed_tool_arg_tokens(
|
||||||
delta_message, output) and tool_parser:
|
delta_message, output) and tool_parser:
|
||||||
|
latest_delta_len = 0
|
||||||
|
if ((isinstance(
|
||||||
|
delta_message.tool_calls[0].function,
|
||||||
|
DeltaFunctionCall)) and isinstance(
|
||||||
|
delta_message.tool_calls[0].function.
|
||||||
|
arguments, str)):
|
||||||
|
latest_delta_len = len(
|
||||||
|
delta_message.tool_calls[0].function.
|
||||||
|
arguments)
|
||||||
|
|
||||||
# get the expected call based on partial JSON
|
# get the expected call based on partial JSON
|
||||||
# parsing which "autocompletes" the JSON
|
# parsing which "autocompletes" the JSON
|
||||||
expected_call = json.dumps(
|
expected_call = json.dumps(
|
||||||
tool_parser.prev_tool_call_arr[index].get(
|
tool_parser.prev_tool_call_arr[index].get(
|
||||||
"arguments", {}))
|
"arguments", {}),
|
||||||
|
ensure_ascii=False)
|
||||||
|
|
||||||
# get what we've streamed so far for arguments
|
# get what we've streamed so far for arguments
|
||||||
# for the current tool
|
# for the current tool
|
||||||
actual_call = tool_parser.streamed_args_for_tool[
|
actual_call = tool_parser.streamed_args_for_tool[
|
||||||
index]
|
index]
|
||||||
|
if (latest_delta_len > 0):
|
||||||
|
actual_call = actual_call[:-latest_delta_len]
|
||||||
|
|
||||||
# check to see if there's anything left to stream
|
# check to see if there's anything left to stream
|
||||||
remaining_call = expected_call.replace(
|
remaining_call = expected_call.replace(
|
||||||
actual_call, "", 1)
|
actual_call, "", 1)
|
||||||
|
|
||||||
# set that as a delta message
|
# set that as a delta message
|
||||||
delta_message = DeltaMessage(tool_calls=[
|
delta_message = DeltaMessage(tool_calls=[
|
||||||
DeltaToolCall(index=index,
|
DeltaToolCall(index=index,
|
||||||
|
@ -91,7 +91,8 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
function=FunctionCall(
|
function=FunctionCall(
|
||||||
name=function_call["name"],
|
name=function_call["name"],
|
||||||
# function call args are JSON but as a string
|
# function call args are JSON but as a string
|
||||||
arguments=json.dumps(function_call["arguments"])))
|
arguments=json.dumps(function_call["arguments"],
|
||||||
|
ensure_ascii=False)))
|
||||||
for function_call in raw_function_calls
|
for function_call in raw_function_calls
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -139,14 +140,27 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
self.tool_call_start_token_id)
|
self.tool_call_start_token_id)
|
||||||
cur_tool_end_count = current_token_ids.count(
|
cur_tool_end_count = current_token_ids.count(
|
||||||
self.tool_call_end_token_id)
|
self.tool_call_end_token_id)
|
||||||
|
tool_call_portion = None
|
||||||
|
text_portion = None
|
||||||
|
|
||||||
# case: if we're generating text, OR rounding out a tool call
|
# case: if we're generating text, OR rounding out a tool call
|
||||||
if (cur_tool_start_count == cur_tool_end_count
|
if (cur_tool_start_count == cur_tool_end_count
|
||||||
and prev_tool_end_count == cur_tool_end_count):
|
and prev_tool_end_count == cur_tool_end_count
|
||||||
|
and self.tool_call_end_token not in delta_text):
|
||||||
logger.debug("Generating text content! skipping tool parsing.")
|
logger.debug("Generating text content! skipping tool parsing.")
|
||||||
if delta_text != self.tool_call_end_token:
|
|
||||||
return DeltaMessage(content=delta_text)
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
|
if self.tool_call_end_token in delta_text:
|
||||||
|
logger.debug("tool_call_end_token in delta_text")
|
||||||
|
full_text = current_text + delta_text
|
||||||
|
tool_call_portion = full_text.split(
|
||||||
|
self.tool_call_start_token)[-1].split(
|
||||||
|
self.tool_call_end_token)[0].rstrip()
|
||||||
|
delta_text = delta_text.split(
|
||||||
|
self.tool_call_end_token)[0].rstrip()
|
||||||
|
text_portion = delta_text.split(
|
||||||
|
self.tool_call_end_token)[-1].lstrip()
|
||||||
|
|
||||||
# case: if tool open & close tag counts don't match, we're doing
|
# case: if tool open & close tag counts don't match, we're doing
|
||||||
# imaginary "else" block here
|
# imaginary "else" block here
|
||||||
# something with tools with this diff.
|
# something with tools with this diff.
|
||||||
@ -184,15 +198,21 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
|
|
||||||
# case -- the current tool call is being closed.
|
# case -- the current tool call is being closed.
|
||||||
elif (cur_tool_start_count == cur_tool_end_count
|
elif (cur_tool_start_count == cur_tool_end_count
|
||||||
and cur_tool_end_count > prev_tool_end_count):
|
and cur_tool_end_count >= prev_tool_end_count):
|
||||||
|
if (self.prev_tool_call_arr is None
|
||||||
|
or len(self.prev_tool_call_arr) == 0):
|
||||||
|
logger.debug(
|
||||||
|
"attempting to close tool call, but no tool call")
|
||||||
|
return None
|
||||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||||
"arguments")
|
"arguments")
|
||||||
if diff:
|
if diff:
|
||||||
diff = diff.encode('utf-8').decode(
|
diff = diff.encode('utf-8').decode(
|
||||||
'unicode_escape') if diff is str else diff
|
'unicode_escape') if diff is str else diff
|
||||||
diff = json.dumps(
|
if ('"}' not in delta_text):
|
||||||
diff, ensure_ascii=False
|
return None
|
||||||
)[len(self.streamed_args_for_tool[self.current_tool_id]):]
|
end_loc = delta_text.rindex('"}')
|
||||||
|
diff = delta_text[:end_loc] + '"}'
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Finishing tool and found diff that had not "
|
"Finishing tool and found diff that had not "
|
||||||
"been streamed yet: %s", diff)
|
"been streamed yet: %s", diff)
|
||||||
@ -221,10 +241,15 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||||
logger.debug('not enough tokens to parse into JSON yet')
|
logger.debug('not enough tokens to parse into JSON yet')
|
||||||
return None
|
return None
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
logger.debug("unable to parse JSON")
|
||||||
|
return None
|
||||||
|
|
||||||
# case - we haven't sent the tool name yet. If it's available, send
|
# case - we haven't sent the tool name yet. If it's available, send
|
||||||
# it. otherwise, wait until it's available.
|
# it. otherwise, wait until it's available.
|
||||||
if not self.current_tool_name_sent:
|
if not self.current_tool_name_sent:
|
||||||
|
if (current_tool_call is None):
|
||||||
|
return None
|
||||||
function_name: Union[str, None] = current_tool_call.get("name")
|
function_name: Union[str, None] = current_tool_call.get("name")
|
||||||
if function_name:
|
if function_name:
|
||||||
self.current_tool_name_sent = True
|
self.current_tool_name_sent = True
|
||||||
@ -284,13 +309,17 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
# autocompleting the JSON
|
# autocompleting the JSON
|
||||||
elif cur_arguments and not prev_arguments:
|
elif cur_arguments and not prev_arguments:
|
||||||
|
|
||||||
cur_arguments_json = json.dumps(cur_arguments)
|
cur_arguments_json = json.dumps(cur_arguments,
|
||||||
|
ensure_ascii=False)
|
||||||
logger.debug("finding %s in %s", delta_text,
|
logger.debug("finding %s in %s", delta_text,
|
||||||
cur_arguments_json)
|
cur_arguments_json)
|
||||||
|
|
||||||
# get the location where previous args differ from current
|
# get the location where previous args differ from current
|
||||||
args_delta_start_loc = cur_arguments_json.index(delta_text) \
|
if (delta_text not in cur_arguments_json[:-2]):
|
||||||
+ len(delta_text)
|
return None
|
||||||
|
args_delta_start_loc = cur_arguments_json[:-2]. \
|
||||||
|
rindex(delta_text) + \
|
||||||
|
len(delta_text)
|
||||||
|
|
||||||
# use that to find the actual delta
|
# use that to find the actual delta
|
||||||
arguments_delta = cur_arguments_json[:args_delta_start_loc]
|
arguments_delta = cur_arguments_json[:args_delta_start_loc]
|
||||||
|
@ -19,7 +19,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
|
|||||||
extract_intermediate_diff)
|
extract_intermediate_diff)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
from vllm.utils import random_uuid
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -109,7 +108,8 @@ class MistralToolParser(ToolParser):
|
|||||||
function=FunctionCall(
|
function=FunctionCall(
|
||||||
name=raw_function_call["name"],
|
name=raw_function_call["name"],
|
||||||
# function call args are JSON but as a string
|
# function call args are JSON but as a string
|
||||||
arguments=json.dumps(raw_function_call["arguments"])))
|
arguments=json.dumps(raw_function_call["arguments"],
|
||||||
|
ensure_ascii=False)))
|
||||||
for raw_function_call in function_call_arr
|
for raw_function_call in function_call_arr
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -199,7 +199,7 @@ class MistralToolParser(ToolParser):
|
|||||||
diff: Union[str, None] = current_tool_call.get("arguments")
|
diff: Union[str, None] = current_tool_call.get("arguments")
|
||||||
|
|
||||||
if diff:
|
if diff:
|
||||||
diff = json.dumps(diff).replace(
|
diff = json.dumps(diff, ensure_ascii=False).replace(
|
||||||
self.streamed_args_for_tool[self.current_tool_id],
|
self.streamed_args_for_tool[self.current_tool_id],
|
||||||
"")
|
"")
|
||||||
delta = DeltaMessage(tool_calls=[
|
delta = DeltaMessage(tool_calls=[
|
||||||
@ -232,7 +232,7 @@ class MistralToolParser(ToolParser):
|
|||||||
delta = DeltaMessage(tool_calls=[
|
delta = DeltaMessage(tool_calls=[
|
||||||
DeltaToolCall(index=self.current_tool_id,
|
DeltaToolCall(index=self.current_tool_id,
|
||||||
type="function",
|
type="function",
|
||||||
id=f"chatcmpl-tool-{random_uuid()}",
|
id=MistralToolCall.generate_random_id(),
|
||||||
function=DeltaFunctionCall(
|
function=DeltaFunctionCall(
|
||||||
name=function_name).model_dump(
|
name=function_name).model_dump(
|
||||||
exclude_none=True))
|
exclude_none=True))
|
||||||
@ -250,6 +250,8 @@ class MistralToolParser(ToolParser):
|
|||||||
cur_arguments = current_tool_call.get("arguments")
|
cur_arguments = current_tool_call.get("arguments")
|
||||||
|
|
||||||
new_text = delta_text.replace("\'", "\"")
|
new_text = delta_text.replace("\'", "\"")
|
||||||
|
if ('"}' in new_text):
|
||||||
|
new_text = new_text[:new_text.rindex('"}')]
|
||||||
|
|
||||||
if not cur_arguments and not prev_arguments:
|
if not cur_arguments and not prev_arguments:
|
||||||
|
|
||||||
@ -260,12 +262,15 @@ class MistralToolParser(ToolParser):
|
|||||||
"mid-arguments")
|
"mid-arguments")
|
||||||
delta = None
|
delta = None
|
||||||
elif cur_arguments and not prev_arguments:
|
elif cur_arguments and not prev_arguments:
|
||||||
cur_arguments_json = json.dumps(cur_arguments)
|
cur_arguments_json = json.dumps(cur_arguments,
|
||||||
|
ensure_ascii=False)[:-2]
|
||||||
logger.debug("finding %s in %s", new_text,
|
logger.debug("finding %s in %s", new_text,
|
||||||
cur_arguments_json)
|
cur_arguments_json)
|
||||||
|
|
||||||
|
if (new_text not in cur_arguments_json):
|
||||||
|
return None
|
||||||
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
||||||
index(new_text) +
|
rindex(new_text) +
|
||||||
len(new_text)]
|
len(new_text)]
|
||||||
logger.debug("First tokens in arguments received: %s",
|
logger.debug("First tokens in arguments received: %s",
|
||||||
arguments_delta)
|
arguments_delta)
|
||||||
@ -279,8 +284,10 @@ class MistralToolParser(ToolParser):
|
|||||||
self.current_tool_id] += arguments_delta
|
self.current_tool_id] += arguments_delta
|
||||||
|
|
||||||
elif cur_arguments and prev_arguments:
|
elif cur_arguments and prev_arguments:
|
||||||
cur_args_json = json.dumps(cur_arguments)
|
cur_args_json = json.dumps(cur_arguments,
|
||||||
prev_args_json = json.dumps(prev_arguments)
|
ensure_ascii=False)
|
||||||
|
prev_args_json = json.dumps(prev_arguments,
|
||||||
|
ensure_ascii=False)
|
||||||
logger.debug("Searching for diff between \n%s\n%s",
|
logger.debug("Searching for diff between \n%s\n%s",
|
||||||
cur_args_json, prev_args_json)
|
cur_args_json, prev_args_json)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user