import json
from typing import Generator, List, Optional

import partial_json_parser
import pytest
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
                                              ToolCall)
from vllm.entrypoints.openai.tool_parsers import JambaToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

MODEL = "ai21labs/Jamba-tiny-dev"


@pytest.fixture(scope="module")
def jamba_tokenizer():
    return get_tokenizer(tokenizer_name=MODEL)


@pytest.fixture
def jamba_tool_parser(jamba_tokenizer):
    return JambaToolParser(jamba_tokenizer)


def assert_tool_calls(actual_tool_calls: List[ToolCall],
                      expected_tool_calls: List[ToolCall]):
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
                                                    expected_tool_calls):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) > 16

        assert actual_tool_call.type == "function"
        assert actual_tool_call.function == expected_tool_call.function


def stream_delta_message_generator(
        jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer,
        model_output: str) -> Generator[DeltaMessage, None, None]:
    """Feed the model output to the tool parser one token at a time and
    yield whatever delta messages the parser emits along the way."""
    all_token_ids = jamba_tokenizer.encode(model_output,
                                           add_special_tokens=False)

    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[:i + 1]

        (new_tokens, delta_text, new_prefix_offset,
         new_read_offset) = detokenize_incrementally(
             tokenizer=jamba_tokenizer,
             all_input_ids=current_token_ids,
             prev_tokens=previous_tokens,
             prefix_offset=prefix_offset,
             read_offset=read_offset,
             skip_special_tokens=False,
             spaces_between_special_tokens=True,
         )

        current_text = previous_text + delta_text

        delta_message = jamba_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=None,  # type: ignore[arg-type]
        )
        if delta_message:
            yield delta_message

        previous_text = current_text
        previous_tokens = (previous_tokens + new_tokens
                           if previous_tokens else new_tokens)
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset


def test_extract_tool_calls_no_tools(jamba_tool_parser):
    model_output = "This is a test"
    extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]
    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content == model_output


# Jamba wraps its tool invocations in <tool_calls>...</tool_calls>, which is
# the delimiter JambaToolParser scans for in the model output.
@pytest.mark.parametrize(
    ids=[
        "single_tool",
        "single_tool_with_content",
        "parallel_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="get_current_weather",
                                               arguments=json.dumps({
                                                   "city": "Dallas",
                                                   "state": "TX",
                                                   "unit": "fahrenheit"
                                               })))
            ],
            None),
        (
            ''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="get_current_weather",
                                               arguments=json.dumps({
                                                   "city": "Dallas",
                                                   "state": "TX",
                                                   "unit": "fahrenheit"
                                               })))
            ],
            " Sure! let me call the tool for you."),
        (
            ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="get_current_weather",
                                               arguments=json.dumps({
                                                   "city": "Dallas",
                                                   "state": "TX",
                                                   "unit": "fahrenheit"
                                               }))),
                ToolCall(function=FunctionCall(name="get_current_weather",
                                               arguments=json.dumps({
                                                   "city": "Orlando",
                                                   "state": "FL",
                                                   "unit": "fahrenheit"
                                               })))
            ],
            None),
    ],
)
def test_extract_tool_calls(jamba_tool_parser, model_output,
                            expected_tool_calls, expected_content):
    extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool",
        "single_tool_with_content",
        "parallel_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ('''This is a test''', [], '''This is a test'''),
        (
            ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="get_current_weather",
                                               arguments=json.dumps({
                                                   "city": "Dallas",
                                                   "state": "TX",
                                                   "unit": "fahrenheit"
                                               })))
            ],
            " "),
        (
            ''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="get_current_weather",
                                               arguments=json.dumps({
                                                   "city": "Dallas",
                                                   "state": "TX",
                                                   "unit": "fahrenheit"
                                               })))
            ],
            " Sure! let me call the tool for you."),
        (
            ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="get_current_weather",
                                               arguments=json.dumps({
                                                   "city": "Dallas",
                                                   "state": "TX",
                                                   "unit": "fahrenheit"
                                               }))),
                ToolCall(function=FunctionCall(name="get_current_weather",
                                               arguments=json.dumps({
                                                   "city": "Orlando",
                                                   "state": "FL",
                                                   "unit": "fahrenheit"
                                               })))
            ],
            " "),
    ],
)
def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
                                      model_output, expected_tool_calls,
                                      expected_content):
    other_content: str = ''
    function_names: List[str] = []
    function_args_strs: List[str] = []
    tool_call_idx: int = -1
    tool_call_ids: List[Optional[str]] = []

    for delta_message in stream_delta_message_generator(
            jamba_tool_parser, jamba_tokenizer, model_output):
        # role should never be streamed from tool parser
        assert not delta_message.role

        if delta_message.content:
            other_content += delta_message.content

        streamed_tool_calls = delta_message.tool_calls

        if streamed_tool_calls and len(streamed_tool_calls) > 0:
            # make sure only one diff is present - correct even for parallel
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            # if a new tool is being called, set up empty arguments
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                function_args_strs.append("")
                tool_call_ids.append(None)

            # if a tool call ID is streamed, make sure one hasn't been already
            if tool_call.id and not tool_call_ids[tool_call.index]:
                tool_call_ids[tool_call.index] = tool_call.id

            # if parts of the function start being streamed
            if tool_call.function:
                # if the function name is defined, set it. It should be
                # streamed IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    function_names.append(tool_call.function.name)

                if tool_call.function.arguments:
                    # make sure they're a string and then add them to the list
                    assert isinstance(tool_call.function.arguments, str)
                    function_args_strs[
                        tool_call.index] += tool_call.function.arguments

    assert other_content == expected_content

    actual_tool_calls = [
        ToolCall(id=tool_call_id,
                 function=FunctionCall(
                     name=function_name,
                     arguments=partial_json_parser.ensure_json(
                         function_args_str, Allow.OBJ | Allow.STR)))
        for tool_call_id, function_name, function_args_str in zip(
            tool_call_ids, function_names, function_args_strs)
    ]
    assert_tool_calls(actual_tool_calls, expected_tool_calls)