import json
from typing import Dict, List, Optional

import openai
import pytest

from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
                    MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
                    WEATHER_TOOL)
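
# The fixtures imported from .utils are not shown here. Judging from the
# assertions below, WEATHER_TOOL is presumably an OpenAI-style function tool
# whose arguments include "city" and "state" strings, roughly along the lines
# of this hypothetical sketch (name and details unconfirmed):
#
# WEATHER_TOOL = {
#     "type": "function",
#     "function": {
#         "name": "get_current_weather",  # assumed name
#         "parameters": {
#             "type": "object",
#             "properties": {
#                 "city": {"type": "string"},
#                 "state": {"type": "string"},
#             },
#         },
#     },
# }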


# test: getting the model to generate parallel tool calls (streaming/not)
# when requested. NOTE that not all models may support this, so some exclusions
# may be added in the future, e.g. llama 3.1 models are not designed to
# support parallel tool calls.
@pytest.mark.asyncio
async def test_parallel_tool_calls(client: openai.AsyncOpenAI):
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False)
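
    # note: the request above uses temperature=0, presumably so sampling is
    # deterministic and the model reliably emits the same two tool calls.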

    choice = chat_completion.choices[0]
    stop_reason = chat_completion.choices[0].finish_reason
    non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls

    # make sure 2 tool calls are present
    assert choice.message.role == "assistant"
    assert non_streamed_tool_calls is not None
    assert len(non_streamed_tool_calls) == 2

    for tool_call in non_streamed_tool_calls:
        # make sure the tool includes a function and ID
        assert tool_call.type == "function"
        assert tool_call.function is not None
        assert isinstance(tool_call.id, str)
        assert len(tool_call.id) > 16

        # make sure the weather tool was called correctly
        assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
        assert isinstance(tool_call.function.arguments, str)

        parsed_arguments = json.loads(tool_call.function.arguments)
        assert isinstance(parsed_arguments, Dict)
        assert isinstance(parsed_arguments.get("city"), str)
        assert isinstance(parsed_arguments.get("state"), str)

    assert stop_reason == "tool_calls"

    # make the same request, streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_tokens=200,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True)

    role_name: Optional[str] = None
    finish_reason_count: int = 0

    tool_call_names: List[str] = []
    tool_call_args: List[str] = []
    tool_call_idx: int = -1
    tool_call_id_count: int = 0
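
    # Streamed tool calls arrive as a sequence of deltas. An illustrative
    # shape (hypothetical values, not captured from a real run):
    #   delta.tool_calls = [{index: 0, id: "...", function: {name: "...", arguments: ""}}]
    #   delta.tool_calls = [{index: 0, function: {arguments: '{"city": "'}}]
    #   ...and then the same pattern again for index 1.
    # The loop below stitches these fragments back into two complete calls.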

    async for chunk in stream:

        # if there's a finish reason make sure it's tools
        if chunk.choices[0].finish_reason:
            finish_reason_count += 1
            assert chunk.choices[0].finish_reason == 'tool_calls'

        # if a role is being streamed make sure it wasn't already set to
        # something else
        if chunk.choices[0].delta.role:
            assert not role_name or role_name == 'assistant'
            role_name = 'assistant'

        # if a tool call is streamed make sure there's exactly one
        # (based on the request parameters)
        streamed_tool_calls = chunk.choices[0].delta.tool_calls

        if streamed_tool_calls and len(streamed_tool_calls) > 0:

            # make sure only one diff is present - correct even for parallel
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            # if a new tool is being called, set up empty arguments
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                tool_call_args.append("")

            # if a tool call ID is streamed, make sure one wasn't already sent
            if tool_call.id:
                tool_call_id_count += 1
                assert (isinstance(tool_call.id, str)
                        and (len(tool_call.id) > 16))

            # if parts of the function start being streamed
            if tool_call.function:
                # if the function name is defined, set it. it should be
                # streamed IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    tool_call_names.append(tool_call.function.name)

                if tool_call.function.arguments:
                    # make sure they're a string and then add them to the list
                    assert isinstance(tool_call.function.arguments, str)

                    tool_call_args[
                        tool_call.index] += tool_call.function.arguments

    assert finish_reason_count == 1
    assert role_name == 'assistant'

    assert (len(non_streamed_tool_calls) == len(tool_call_names) ==
            len(tool_call_args))

    for i in range(2):
        assert non_streamed_tool_calls[i].function.name == tool_call_names[i]
        streamed_args = json.loads(tool_call_args[i])
        non_streamed_args = json.loads(
            non_streamed_tool_calls[i].function.arguments)
        assert streamed_args == non_streamed_args


# test: providing parallel tool calls back to the model to get a response
# (streaming/not)
@pytest.mark.asyncio
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI):
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
        temperature=0,
        max_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False)
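
    # MESSAGES_WITH_PARALLEL_TOOL_RESPONSE (from .utils) presumably ends with
    # tool-role messages carrying the two parallel call results; the content
    # checks below assume those results report a temperature of 98 for Dallas
    # and 78 for Orlando.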

    choice = chat_completion.choices[0]

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert choice.message.role == "assistant"
    assert choice.message.tool_calls is None \
        or len(choice.message.tool_calls) == 0
    assert choice.message.content is not None
    assert "98" in choice.message.content  # Dallas temp in tool response
    assert "78" in choice.message.content  # Orlando temp in tool response

    stream = await client.chat.completions.create(
        messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
        temperature=0,
        max_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True)

    chunks: List[str] = []
    finish_reason_count = 0
    role_sent: bool = False

    async for chunk in stream:
        delta = chunk.choices[0].delta

        if delta.role:
            assert not role_sent
            assert delta.role == "assistant"
            role_sent = True

        if delta.content:
            chunks.append(delta.content)

        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
            assert chunk.choices[0].finish_reason == choice.finish_reason

        assert not delta.tool_calls or len(delta.tool_calls) == 0

    assert role_sent
    assert finish_reason_count == 1
    assert len(chunks)
    assert "".join(chunks) == choice.message.content
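
# These tests presume a running OpenAI-compatible server plus a `client`
# fixture, typically provided by a conftest.py. A likely invocation, assuming
# this module is named test_parallel_tool_calls.py (unconfirmed):
#
#   pytest -q test_parallel_tool_calls.py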