[Model] Add mistral function calling format to all models loaded with "mistral" format (#8515)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in: parent 9855b99502 · commit a54ed80249
examples/offline_chat_with_tools.py (new file, 138 lines)
@@ -0,0 +1,138 @@
# ruff: noqa
import json
import random
import string

from vllm import LLM
from vllm.sampling_params import SamplingParams

# This script is an offline demo for function calling
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
#     "model": "mistralai/Mistral-7B-Instruct-v0.3",
#     "messages": [
#       {
#         "role": "user",
#         "content": [
#             {"type" : "text", "text": "Describe this image in detail please."},
#             {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
#             {"type" : "text", "text": "and this one as well. Answer in French."},
#             {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
#         ]
#       }
#     ]
#   }'
# ```
#
# Usage:
#     python demo.py simple
#     python demo.py advanced

model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# or switch to "mistralai/Mistral-Nemo-Instruct-2407"
# or "mistralai/Mistral-Large-Instruct-2407"
# or any other mistral model with function calling ability

sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
llm = LLM(model=model_name,
          tokenizer_mode="mistral",
          config_format="mistral",
          load_format="mistral")


def generate_random_id(length=9):
    characters = string.ascii_letters + string.digits
    random_id = ''.join(random.choice(characters) for _ in range(length))
    return random_id


# simulate an API that can be called
def get_current_weather(city: str, state: str, unit: str):
    return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
            "partly cloudy, with highs in the 90's.")


tool_functions = {"get_current_weather": get_current_weather}

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type":
                    "string",
                    "description":
                    "The city to find the weather for, e.g. 'San Francisco'"
                },
                "state": {
                    "type":
                    "string",
                    "description":
                    "the two-letter abbreviation for the state that the city is"
                    " in, e.g. 'CA' which would mean 'California'"
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["city", "state", "unit"]
        }
    }
}]

messages = [{
    "role":
    "user",
    "content":
    "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
}]

outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
output = outputs[0].outputs[0].text.strip()

# append the assistant message
messages.append({
    "role": "assistant",
    "content": output,
})

# now parse and execute the model's output, simulating an API call by using
# the function defined above
tool_calls = json.loads(output)
tool_answers = [
    tool_functions[call['name']](**call['arguments']) for call in tool_calls
]

# append the answer as a tool message and let the LLM give you an answer
messages.append({
    "role": "tool",
    "content": "\n\n".join(tool_answers),
    "tool_call_id": generate_random_id(),
})

outputs = llm.chat(messages, sampling_params, tools=tools)

print(outputs[0].outputs[0].text.strip())
# yields
#   'The weather in Dallas, TX is 85 degrees fahrenheit. '
#   'It is partly cloudy, with highs in the 90's.'
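For reference, a minimal sketch (not part of the commit) of the raw tool-call string that `json.loads(output)` in the example above is expected to receive; the literal is copied from EXPECTED_FUNC_CALL in the test diff below, and the variable names are illustrative.

import json

raw_output = ('[{"name": "get_current_weather", "arguments": '
              '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
calls = json.loads(raw_output)  # a list of {"name": ..., "arguments": {...}} dicts
assert calls[0]["name"] == "get_current_weather"
assert calls[0]["arguments"] == {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}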
tests/models/test_mistral.py
@@ -4,13 +4,61 @@ Run `pytest tests/models/test_mistral.py`.
 """
 import pytest
 
+from vllm import SamplingParams
+
 from ...utils import check_logprobs_close
 
 MODELS = [
     "mistralai/Mistral-7B-Instruct-v0.1",
     "mistralai/Mistral-7B-Instruct-v0.3",
+    # Mistral-Nemo is too big for CI, but passes locally
+    # "mistralai/Mistral-Nemo-Instruct-2407"
 ]
+
+SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
+
+# for function calling
+TOOLS = [{
+    "type": "function",
+    "function": {
+        "name": "get_current_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "city": {
+                    "type":
+                    "string",
+                    "description":
+                    "The city to find the weather for, e.g. 'San Francisco'"
+                },
+                "state": {
+                    "type":
+                    "string",
+                    "description":
+                    "the two-letter abbreviation for the state that the city is"
+                    " in, e.g. 'CA' which would mean 'California'"
+                },
+                "unit": {
+                    "type": "string",
+                    "description": "The unit to fetch the temperature in",
+                    "enum": ["celsius", "fahrenheit"]
+                }
+            },
+            "required": ["city", "state", "unit"]
+        }
+    }
+}]
+MSGS = [{
+    "role":
+    "user",
+    "content": ("Can you tell me what the temperature"
+                " will be in Dallas, in fahrenheit?")
+}]
+EXPECTED_FUNC_CALL = (
+    '[{"name": "get_current_weather", "arguments": '
+    '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
 
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
@@ -81,3 +129,22 @@ def test_mistral_format(
         name_0="hf",
         name_1="mistral",
     )
+
+
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("model", MODELS[1:])  # v1 can't do func calling
+def test_mistral_function_calling(
+    vllm_runner,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model,
+                     dtype=dtype,
+                     tokenizer_mode="mistral",
+                     config_format="mistral",
+                     load_format="mistral") as vllm_model:
+        outputs = vllm_model.model.chat(MSGS,
+                                        tools=TOOLS,
+                                        sampling_params=SAMPLING_PARAMS)
+
+        assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL
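The new test can be selected on its own with `pytest tests/models/test_mistral.py -k function_calling`; like the rest of this file it loads the model through `vllm_runner`, so it presumably needs a GPU and the Mistral-7B-Instruct-v0.3 weights available.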
vllm/entrypoints/llm.py
@@ -1,5 +1,6 @@
 from contextlib import contextmanager
-from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
+from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
+                    overload)
 
 from tqdm import tqdm
 
@@ -357,6 +358,7 @@ class LLM:
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = True,
+        tools: Optional[List[Dict[str, Any]]] = None,
     ) -> List[RequestOutput]:
         """
         Generate responses for a chat conversation.
@@ -401,6 +403,7 @@ class LLM:
                 messages=messages,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                tools=tools,
             )
         else:
             prompt = apply_hf_chat_template(
@@ -408,6 +411,7 @@ class LLM:
                 conversation=conversation,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                tools=tools,
             )
 
         inputs: PromptInputs
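A condensed usage sketch (not part of the commit) of the `tools=` parameter added to `LLM.chat` above; it restates what examples/offline_chat_with_tools.py does, and the trimmed tool schema below is assumed to behave the same as the full one in that example.

from vllm import LLM
from vllm.sampling_params import SamplingParams

# Same tool schema as in examples/offline_chat_with_tools.py, trimmed for brevity.
tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "state": {"type": "string"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["city", "state", "unit"],
        },
    },
}]

llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
          tokenizer_mode="mistral",
          config_format="mistral",
          load_format="mistral")
outputs = llm.chat(
    [{"role": "user",
      "content": "What will the temperature be in Dallas, in fahrenheit?"}],
    sampling_params=SamplingParams(max_tokens=8192, temperature=0.0),
    tools=tools,
)
print(outputs[0].outputs[0].text)  # expected: a raw tool-call JSON string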
vllm/entrypoints/openai/serving_chat.py
@@ -123,7 +123,8 @@ class OpenAIServingChat(OpenAIServing):
             ]
 
             prompt: Union[str, List[int]]
-            if isinstance(tokenizer, MistralTokenizer):
+            is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
+            if is_mistral_tokenizer:
                 prompt = apply_mistral_chat_template(
                     tokenizer,
                     messages=request.messages,
@@ -159,10 +160,10 @@ class OpenAIServingChat(OpenAIServing):
                 return self.create_error_response(
                     "tool_choice = \"required\" is not supported!")
 
-            # "auto" tools requires --enable-auto-tool-choice
-            # and --tool-call-parser
-            if request.tool_choice == "auto" and not (
+            if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
                     self.enable_auto_tools and self.tool_parser is not None):
+                # for hf tokenizers, "auto" tools requires
+                # --enable-auto-tool-choice and --tool-call-parser
                 return self.create_error_response(
                     "\"auto\" tool choice requires "
                     "--enable-auto-tool-choice and --tool-call-parser to be set")
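With the change above, a request using `tool_choice="auto"` no longer hits the `--enable-auto-tool-choice`/`--tool-call-parser` error when the server runs with a Mistral tokenizer; the tool definitions are passed through to the Mistral chat template instead. Below is a client-side sketch using the `openai` Python package against the server command from the example file; this client code is an assumption about standard OpenAI-compatible usage, not part of this commit.

# Server (from examples/offline_chat_with_tools.py):
#   vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral \
#       --load-format mistral --config-format mistral
from openai import OpenAI

# Same trimmed tool schema as in the sketch above.
tools = [{"type": "function", "function": {
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {"type": "object",
                   "properties": {"city": {"type": "string"},
                                  "state": {"type": "string"},
                                  "unit": {"type": "string",
                                           "enum": ["celsius", "fahrenheit"]}},
                   "required": ["city", "state", "unit"]}}}]

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token")
response = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    messages=[{"role": "user",
               "content": "What will the temperature be in Dallas, in fahrenheit?"}],
    tools=tools,
    tool_choice="auto",  # no --enable-auto-tool-choice needed for mistral tokenizers
)
print(response.choices[0].message)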
vllm/transformers_utils/tokenizers/mistral.py
@@ -165,10 +165,9 @@ class MistralTokenizer:
                             messages: List["ChatCompletionMessageParam"],
                             tools: Optional[Dict[str, Any]] = None,
                             **kwargs) -> List[int]:
-        assert tools is None, "`tools` are not yet supported."
 
-        request = ChatCompletionRequest(
-            messages=messages)  # type: ignore[type-var]
+        request = ChatCompletionRequest(messages=messages,
+                                        tools=tools)  # type: ignore[type-var]
         encoded = self.mistral.encode_chat_completion(request)
 
         # encode-decode to get clean prompt
@@ -176,7 +175,8 @@ class MistralTokenizer:
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         if isinstance(self.tokenizer, Tekkenizer):
-            return "".join(tokens)
+            return "".join(t for t in tokens
+                           if t not in self.tokenizer._all_special_tokens)
         else:
             return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
 
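To make the `convert_tokens_to_string` change concrete, a stand-alone sketch (not the real class) of the new Tekkenizer branch; the special-token set below is invented for illustration and is not the tokenizer's actual `_all_special_tokens`.

# Illustrative only: special tokens are now filtered before joining token pieces.
_all_special_tokens = {"<s>", "</s>", "[INST]", "[/INST]", "[TOOL_CALLS]"}

def convert_tokens_to_string(tokens):
    return "".join(t for t in tokens if t not in _all_special_tokens)

print(convert_tokens_to_string(["<s>", "Hello", ",", " world", "</s>"]))
# -> 'Hello, world'  (special tokens no longer leak into the decoded text)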