vllm/tests/tokenization/test_mistral_tokenizer.py

84 lines
2.6 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import pytest
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool
from vllm.transformers_utils.tokenizers.mistral import (
make_mistral_chat_completion_request)
# yapf: enable
@pytest.mark.parametrize(
    "openai_request,expected_mistral_request",
    [(
        {
            "messages": [{
                "role": "user",
                "content": "What is the current local date and time?",
            }],
            "tools": [{
                "type": "function",
                "function": {
                    "description": "Fetch the current local date and time.",
                    "name": "get_current_time",
                },
            }],
        },
        ChatCompletionRequest(
            messages=[
                UserMessage(content="What is the current local date and time?")
            ],
            tools=[
                Tool(
                    type="function",
                    function=Function(
                        name="get_current_time",
                        description="Fetch the current local date and time.",
                        parameters={},
                    ),
                )
            ],
        ),
    ),
     (
         {
             "messages": [{
                 "role": "user",
                 "content": "What is the current local date and time?",
             }],
             "tools": [{
                 "type": "function",
                 "function": {
                     "description": "Fetch the current local date and time.",
                     "name": "get_current_time",
                     # Explicit null, as opposed to the key being absent
                     # in the first case.
                     "parameters": None,
                 },
             }],
         },
         ChatCompletionRequest(
             messages=[
                 UserMessage(
                     content="What is the current local date and time?")
             ],
             tools=[
                 Tool(
                     type="function",
                     function=Function(
                         name="get_current_time",
                         description="Fetch the current local date and time.",
                         parameters={},
                     ),
                 )
             ],
         ),
     )],
    # Name each case so a failure report identifies which input shape broke.
    ids=["tool_without_parameters", "tool_with_null_parameters"],
)
def test_make_mistral_chat_completion_request(openai_request,
                                              expected_mistral_request):
    """Check conversion of OpenAI-style messages/tools to a Mistral request.

    ``openai_request`` is an OpenAI-style payload whose ``"messages"`` and
    ``"tools"`` entries are passed to
    ``make_mistral_chat_completion_request``. The result must equal
    ``expected_mistral_request``.

    The two cases differ only in the tool's ``"parameters"`` field:

    * in the first case the key is missing;
    * in the second case it is ``None``.

    In both cases the expected ``Function`` carries ``parameters={}``.
    """
    actual = make_mistral_chat_completion_request(openai_request["messages"],
                                                  openai_request["tools"])
    # Bind to a local first so pytest's assertion rewriting shows a clear diff.
    assert actual == expected_mistral_request