[Model] Add mistral function calling format to all models loaded with "mistral" format (#8515)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Patrick von Platen 2024-09-17 19:50:37 +02:00 committed by GitHub
parent 9855b99502
commit a54ed80249
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 219 additions and 9 deletions

View File

@ -0,0 +1,138 @@
# ruff: noqa
import json
import random
import string
from vllm import LLM
from vllm.sampling_params import SamplingParams
# This script is an offline demo for function calling
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
# "model": "mistralai/Mistral-7B-Instruct-v0.3"
# "messages": [
# {
# "role": "user",
# "content": [
# {"type" : "text", "text": "Describe this image in detail please."},
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
# {"type" : "text", "text": "and this one as well. Answer in French."},
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
# ]
# }
# ]
# }'
# ```
#
# Usage:
# python demo.py simple
# python demo.py advanced
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# or switch to "mistralai/Mistral-Nemo-Instruct-2407"
# or "mistralai/Mistral-Large-Instruct-2407"
# or any other mistral model with function calling ability
sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
llm = LLM(model=model_name,
tokenizer_mode="mistral",
config_format="mistral",
load_format="mistral")
def generate_random_id(length=9):
characters = string.ascii_letters + string.digits
random_id = ''.join(random.choice(characters) for _ in range(length))
return random_id
# simulate an API that can be called
def get_current_weather(city: str, state: str, unit: 'str'):
return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
"partly cloudly, with highs in the 90's.")
tool_funtions = {"get_current_weather": get_current_weather}
tools = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["city", "state", "unit"]
}
}
}]
messages = [{
"role":
"user",
"content":
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
output = outputs[0].outputs[0].text.strip()
# append the assistant message
messages.append({
"role": "assistant",
"content": output,
})
# let's now actually parse and execute the model's output simulating an API call by using the
# above defined function
tool_calls = json.loads(output)
tool_answers = [
tool_funtions[call['name']](**call['arguments']) for call in tool_calls
]
# append the answer as a tool message and let the LLM give you an answer
messages.append({
"role": "tool",
"content": "\n\n".join(tool_answers),
"tool_call_id": generate_random_id(),
})
outputs = llm.chat(messages, sampling_params, tools=tools)
print(outputs[0].outputs[0].text.strip())
# yields
# 'The weather in Dallas, TX is 85 degrees fahrenheit. '
# 'It is partly cloudly, with highs in the 90's.'

View File

@ -4,13 +4,61 @@ Run `pytest tests/models/test_mistral.py`.
""" """
import pytest import pytest
from vllm import SamplingParams
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
MODELS = [ MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Mistral-7B-Instruct-v0.3",
# Mistral-Nemo is to big for CI, but passes locally
# "mistralai/Mistral-Nemo-Instruct-2407"
] ]
SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
# for function calling
TOOLS = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["city", "state", "unit"]
}
}
}]
MSGS = [{
"role":
"user",
"content": ("Can you tell me what the temperate"
" will be in Dallas, in fahrenheit?")
}]
EXPECTED_FUNC_CALL = (
'[{"name": "get_current_weather", "arguments": '
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@ -81,3 +129,22 @@ def test_mistral_format(
name_0="hf", name_0="hf",
name_1="mistral", name_1="mistral",
) )
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("model", MODELS[1:]) # v1 can't do func calling
def test_mistral_function_calling(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model,
dtype=dtype,
tokenizer_mode="mistral",
config_format="mistral",
load_format="mistral") as vllm_model:
outputs = vllm_model.model.chat(MSGS,
tools=TOOLS,
sampling_params=SAMPLING_PARAMS)
assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL

View File

@ -1,5 +1,6 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
overload)
from tqdm import tqdm from tqdm import tqdm
@ -357,6 +358,7 @@ class LLM:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
tools: Optional[List[Dict[str, Any]]] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
""" """
Generate responses for a chat conversation. Generate responses for a chat conversation.
@ -401,6 +403,7 @@ class LLM:
messages=messages, messages=messages,
chat_template=chat_template, chat_template=chat_template,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
tools=tools,
) )
else: else:
prompt = apply_hf_chat_template( prompt = apply_hf_chat_template(
@ -408,6 +411,7 @@ class LLM:
conversation=conversation, conversation=conversation,
chat_template=chat_template, chat_template=chat_template,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
tools=tools,
) )
inputs: PromptInputs inputs: PromptInputs

View File

@ -123,7 +123,8 @@ class OpenAIServingChat(OpenAIServing):
] ]
prompt: Union[str, List[int]] prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer): is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
if is_mistral_tokenizer:
prompt = apply_mistral_chat_template( prompt = apply_mistral_chat_template(
tokenizer, tokenizer,
messages=request.messages, messages=request.messages,
@ -159,10 +160,10 @@ class OpenAIServingChat(OpenAIServing):
return self.create_error_response( return self.create_error_response(
"tool_choice = \"required\" is not supported!") "tool_choice = \"required\" is not supported!")
# "auto" tools requires --enable-auto-tool-choice if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
# and --tool-call-parser
if request.tool_choice == "auto" and not (
self.enable_auto_tools and self.tool_parser is not None): self.enable_auto_tools and self.tool_parser is not None):
# for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
return self.create_error_response( return self.create_error_response(
"\"auto\" tool choice requires " "\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set") "--enable-auto-tool-choice and --tool-call-parser to be set")

View File

@ -165,10 +165,9 @@ class MistralTokenizer:
messages: List["ChatCompletionMessageParam"], messages: List["ChatCompletionMessageParam"],
tools: Optional[Dict[str, Any]] = None, tools: Optional[Dict[str, Any]] = None,
**kwargs) -> List[int]: **kwargs) -> List[int]:
assert tools is None, "`tools` are not yet supported."
request = ChatCompletionRequest( request = ChatCompletionRequest(messages=messages,
messages=messages) # type: ignore[type-var] tools=tools) # type: ignore[type-var]
encoded = self.mistral.encode_chat_completion(request) encoded = self.mistral.encode_chat_completion(request)
# encode-decode to get clean prompt # encode-decode to get clean prompt
@ -176,7 +175,8 @@ class MistralTokenizer:
def convert_tokens_to_string(self, tokens: List[str]) -> str: def convert_tokens_to_string(self, tokens: List[str]) -> str:
if isinstance(self.tokenizer, Tekkenizer): if isinstance(self.tokenizer, Tekkenizer):
return "".join(tokens) return "".join(t for t in tokens
if t not in self.tokenizer._all_special_tokens)
else: else:
return self.tokenizer.decode(tokens) # type: ignore[arg-type] return self.tokenizer.decode(tokens) # type: ignore[arg-type]