[Frontend] Clean up type annotations for mistral tokenizer (#8314)

Cyrus Leung 2024-09-11 00:49:11 +08:00 committed by GitHub
parent 6234385f4a
commit 8c054b7a62
6 changed files with 115 additions and 60 deletions
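
At a glance, the change replaces the untyped `apply_chat_template` helper with two typed helpers and dispatches on the tokenizer class at each call site. A minimal sketch of that pattern, assuming vLLM at this commit is installed (the `render_prompt` wrapper and its name are illustrative, not part of the diff):

```python
from typing import List, Optional, Union

from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer


def render_prompt(
    tokenizer: AnyTokenizer,
    conversation: List[ConversationMessage],
    messages: List[ChatCompletionMessageParam],
    chat_template: Optional[str] = None,
) -> Union[str, List[int]]:
    # Hypothetical wrapper mirroring the call sites below: the Mistral
    # tokenizer consumes the raw OpenAI-format messages and returns token
    # IDs, while HF tokenizers consume the parsed conversation and return
    # the rendered template string.
    prompt: Union[str, List[int]]
    if isinstance(tokenizer, MistralTokenizer):
        prompt = apply_mistral_chat_template(
            tokenizer,
            messages=messages,
            chat_template=chat_template,
            add_generation_prompt=True,
        )
    else:
        prompt = apply_hf_chat_template(
            tokenizer,
            conversation=conversation,
            chat_template=chat_template,
            add_generation_prompt=True,
        )
    return prompt
```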

View File

@@ -1,6 +1,7 @@
import pytest
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
load_chat_template)
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt=add_generation_prompt)
# Call the function and get the result
result = apply_chat_template(
result = apply_hf_chat_template(
tokenizer,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,

View File

@@ -23,6 +23,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
@@ -31,7 +32,7 @@ from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image,
get_and_parse_audio, get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
logger = init_logger(__name__)
@@ -379,6 +380,9 @@ def _parse_chat_message_content_parts(
audio_url = _AudioParser(part)["audio_url"]
mm_parser.parse_audio(audio_url["url"])
elif part_type == "refusal":
text = _RefusalParser(part)["refusal"]
texts.append(text)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
@@ -433,6 +437,21 @@ def _parse_chat_message_content(
return result
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in messages:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):
for item in message["tool_calls"]:
item["function"]["arguments"] = json.loads(
item["function"]["arguments"])
def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
@@ -446,6 +465,8 @@ def parse_chat_messages(
conversation.extend(sub_messages)
_postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data()
@@ -462,41 +483,41 @@ def parse_chat_messages_futures(
conversation.extend(sub_messages)
_postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data()
def apply_chat_template(
tokenizer: AnyTokenizer,
def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: List[ConversationMessage],
chat_template: Optional[str],
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> Union[str, List[int]]:
) -> str:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in conversation:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):
for i in range(len(message["tool_calls"])):
args: str = message["tool_calls"][i]["function"]["arguments"]
parsed_args: Dict = json.loads(args)
message["tool_calls"][i]["function"]["arguments"] = parsed_args
prompt = tokenizer.apply_chat_template(
conversation=conversation,
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
chat_template=chat_template,
tokenize=tokenize,
**kwargs,
)
return prompt
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
messages: List[ChatCompletionMessageParam],
chat_template: Optional[str],
**kwargs: Any,
) -> List[int]:
return tokenizer.apply_chat_template(
messages=messages,
chat_template=chat_template,
**kwargs,
)
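
The new `_postprocess_messages` hook factors the tool-call argument conversion out of the old `apply_chat_template`. A standalone sketch of the transformation it performs, using plain dicts instead of vLLM's `ConversationMessage` TypedDict:

```python
import json

messages = [{
    "role": "assistant",
    "content": None,
    "tool_calls": [{
        "id": "call_1",
        "type": "function",
        "function": {
            "name": "get_weather",
            "arguments": '{"city": "Paris"}',  # JSON string, as sent over the API
        },
    }],
}]

# Same in-place conversion _postprocess_messages applies before templating:
# tool-use chat templates expect the arguments as a dict, not a JSON string.
for message in messages:
    if (message["role"] == "assistant" and "tool_calls" in message
            and isinstance(message["tool_calls"], list)):
        for item in message["tool_calls"]:
            item["function"]["arguments"] = json.loads(
                item["function"]["arguments"])

assert messages[0]["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"}
```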

View File

@@ -6,7 +6,8 @@ from tqdm import tqdm
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_chat_template,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
@@ -19,7 +20,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
@@ -393,9 +394,18 @@ class LLM:
conversation, mm_data = parse_chat_messages(messages, model_config,
tokenizer)
prompt = apply_chat_template(
prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
prompt = apply_mistral_chat_template(
tokenizer,
conversation,
messages=messages,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
)
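
For the offline entrypoint the dispatch is transparent to callers. A usage sketch, assuming a Mistral model loaded with `tokenizer_mode="mistral"` (model name and sampling settings are illustrative, not taken from the diff):

```python
from vllm import LLM, SamplingParams

# With tokenizer_mode="mistral", LLM.chat() routes through
# apply_mistral_chat_template and feeds token IDs to the engine; any other
# tokenizer goes through apply_hf_chat_template and produces a prompt string.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral")
outputs = llm.chat(
    [{"role": "user", "content": "Summarize unified diffs in one sentence."}],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```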

View File

@@ -11,7 +11,8 @@ from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_chat_template,
apply_hf_chat_template,
apply_mistral_chat_template,
load_chat_template,
parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
@@ -35,7 +36,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid
logger = init_logger(__name__)
@@ -121,7 +122,19 @@ class OpenAIServingChat(OpenAIServing):
tool.model_dump() for tool in request.tools
]
prompt = apply_chat_template(
prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
prompt = apply_mistral_chat_template(
tokenizer,
messages=request.messages,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=request.chat_template or self.chat_template,
@@ -307,11 +320,10 @@ class OpenAIServingChat(OpenAIServing):
# Send response to echo the input portion of the
# last message
if request.echo:
last_msg_content: Optional[str] = ""
if conversation and conversation[-1].get(
"content") and conversation[-1].get(
"role") == role:
last_msg_content = conversation[-1]["content"]
last_msg_content: str = ""
if conversation and "content" in conversation[
-1] and conversation[-1].get("role") == role:
last_msg_content = conversation[-1]["content"] or ""
if last_msg_content:
for i in range(num_choices):
@@ -659,8 +671,8 @@ class OpenAIServingChat(OpenAIServing):
if request.echo:
last_msg_content = ""
if conversation and conversation[-1].get(
"content") and conversation[-1].get("role") == role:
if conversation and "content" in conversation[-1] and conversation[
-1].get("role") == role:
last_msg_content = conversation[-1]["content"] or ""
for choice in choices:
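
On the server side the same per-tokenizer dispatch happens for each chat completion request. A client-side sketch against a running vLLM OpenAI-compatible server (base URL and model name are assumptions), using the `echo` extra parameter whose last-message handling is adjusted above:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    messages=[{"role": "user", "content": "Say hello."}],
    # vLLM extension: prepend the last input message's content to the
    # response when it shares the response role (handled in the hunks above).
    extra_body={"echo": True},
)
print(resp.choices[0].message.content)
```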

View File

@@ -2,7 +2,8 @@ from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (apply_chat_template,
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
apply_mistral_chat_template,
load_chat_template,
parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
@@ -18,6 +19,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@@ -66,6 +68,7 @@ class OpenAIServingTokenization(OpenAIServing):
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
prompt: Union[str, List[int]]
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
@@ -77,7 +80,15 @@ class OpenAIServingTokenization(OpenAIServing):
logger.warning(
"Multi-modal inputs are ignored during tokenization")
prompt = apply_chat_template(
if isinstance(tokenizer, MistralTokenizer):
prompt = apply_mistral_chat_template(
tokenizer,
messages=request.messages,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=self.chat_template,
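
The tokenization endpoint gains the same branch. A request sketch against a running server (URL and model name are assumptions) showing chat `messages` being rendered through the appropriate template before tokenization:

```python
import requests

# POST /tokenize with chat messages: the server applies
# apply_mistral_chat_template or apply_hf_chat_template first, then tokenizes.
resp = requests.post(
    "http://localhost:8000/tokenize",
    json={
        "model": "mistralai/Mistral-7B-Instruct-v0.3",
        "messages": [{"role": "user", "content": "Hello!"}],
        "add_generation_prompt": True,
    },
)
print(resp.json()["tokens"])
```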

View File

@@ -16,7 +16,7 @@ from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
Tekkenizer)
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@dataclass
@@ -128,13 +128,13 @@ class MistralTokenizer:
return self.tokenizer.encode(prompt, bos=True, eos=False)
def apply_chat_template(self,
conversation: List["ConversationMessage"],
messages: List["ChatCompletionMessageParam"],
tools: Optional[Dict[str, Any]] = None,
**kwargs) -> List[int]:
assert tools is None, "`tools` are not yet supported."
request = ChatCompletionRequest(
messages=conversation) # type: ignore[type-var]
messages=messages) # type: ignore[type-var]
encoded = self.mistral.encode_chat_completion(request)
# encode-decode to get clean prompt
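
At the lowest level, `MistralTokenizer.apply_chat_template` now accepts OpenAI-format `messages` and returns token IDs rather than a rendered string. A direct-use sketch, assuming the class's `from_pretrained` loader and an example repo id (both hedged, not confirmed by this diff):

```python
from vllm.transformers_utils.tokenizer import MistralTokenizer

# Assumed loader; downloads the mistral_common tokenizer file from the hub.
tok = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

# After this change the method takes OpenAI-style message dicts under the
# `messages` keyword and returns a List[int] of token IDs.
token_ids = tok.apply_chat_template(
    messages=[{"role": "user", "content": "Hello!"}],
)
print(len(token_ids), token_ids[:8])
```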