[Frontend] Clean up type annotations for mistral tokenizer (#8314)
This commit is contained in:
parent
6234385f4a
commit
8c054b7a62
@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
|
||||
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
|
||||
load_chat_template)
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
add_generation_prompt=add_generation_prompt)
|
||||
|
||||
# Call the function and get the result
|
||||
result = apply_chat_template(
|
||||
result = apply_hf_chat_template(
|
||||
tokenizer,
|
||||
conversation=mock_request.messages,
|
||||
chat_template=mock_request.chat_template or template_content,
|
||||
|
@ -23,6 +23,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
|
||||
# yapf: enable
|
||||
# pydantic needs the TypedDict from typing_extensions
|
||||
from pydantic import ConfigDict
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from typing_extensions import Required, TypeAlias, TypedDict
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
@ -31,7 +32,7 @@ from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.utils import (async_get_and_parse_audio,
|
||||
async_get_and_parse_image,
|
||||
get_and_parse_audio, get_and_parse_image)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -379,6 +380,9 @@ def _parse_chat_message_content_parts(
|
||||
audio_url = _AudioParser(part)["audio_url"]
|
||||
|
||||
mm_parser.parse_audio(audio_url["url"])
|
||||
elif part_type == "refusal":
|
||||
text = _RefusalParser(part)["refusal"]
|
||||
texts.append(text)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||
|
||||
@ -433,6 +437,21 @@ def _parse_chat_message_content(
|
||||
return result
|
||||
|
||||
|
||||
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
|
||||
# per the Transformers docs & maintainers, tool call arguments in
|
||||
# assistant-role messages with tool_calls need to be dicts not JSON str -
|
||||
# this is how tool-use chat templates will expect them moving forwards
|
||||
# so, for messages that have tool_calls, parse the string (which we get
|
||||
# from openAI format) to dict
|
||||
for message in messages:
|
||||
if (message["role"] == "assistant" and "tool_calls" in message
|
||||
and isinstance(message["tool_calls"], list)):
|
||||
|
||||
for item in message["tool_calls"]:
|
||||
item["function"]["arguments"] = json.loads(
|
||||
item["function"]["arguments"])
|
||||
|
||||
|
||||
def parse_chat_messages(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
@ -446,6 +465,8 @@ def parse_chat_messages(
|
||||
|
||||
conversation.extend(sub_messages)
|
||||
|
||||
_postprocess_messages(conversation)
|
||||
|
||||
return conversation, mm_tracker.all_mm_data()
|
||||
|
||||
|
||||
@ -462,41 +483,41 @@ def parse_chat_messages_futures(
|
||||
|
||||
conversation.extend(sub_messages)
|
||||
|
||||
_postprocess_messages(conversation)
|
||||
|
||||
return conversation, mm_tracker.all_mm_data()
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
tokenizer: AnyTokenizer,
|
||||
def apply_hf_chat_template(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
conversation: List[ConversationMessage],
|
||||
chat_template: Optional[str],
|
||||
*,
|
||||
tokenize: bool = False, # Different from HF's default
|
||||
**kwargs: Any,
|
||||
) -> Union[str, List[int]]:
|
||||
) -> str:
|
||||
if chat_template is None and tokenizer.chat_template is None:
|
||||
raise ValueError(
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one.")
|
||||
|
||||
# per the Transformers docs & maintainers, tool call arguments in
|
||||
# assistant-role messages with tool_calls need to be dicts not JSON str -
|
||||
# this is how tool-use chat templates will expect them moving forwards
|
||||
# so, for messages that have tool_calls, parse the string (which we get
|
||||
# from openAI format) to dict
|
||||
for message in conversation:
|
||||
if (message["role"] == "assistant" and "tool_calls" in message
|
||||
and isinstance(message["tool_calls"], list)):
|
||||
|
||||
for i in range(len(message["tool_calls"])):
|
||||
args: str = message["tool_calls"][i]["function"]["arguments"]
|
||||
parsed_args: Dict = json.loads(args)
|
||||
message["tool_calls"][i]["function"]["arguments"] = parsed_args
|
||||
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
conversation=conversation,
|
||||
return tokenizer.apply_chat_template(
|
||||
conversation=conversation, # type: ignore[arg-type]
|
||||
chat_template=chat_template,
|
||||
tokenize=tokenize,
|
||||
**kwargs,
|
||||
)
|
||||
return prompt
|
||||
|
||||
|
||||
def apply_mistral_chat_template(
|
||||
tokenizer: MistralTokenizer,
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
chat_template: Optional[str],
|
||||
**kwargs: Any,
|
||||
) -> List[int]:
|
||||
return tokenizer.apply_chat_template(
|
||||
messages=messages,
|
||||
chat_template=chat_template,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -6,7 +6,8 @@ from tqdm import tqdm
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
apply_chat_template,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages)
|
||||
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
@ -19,7 +20,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
get_cached_tokenizer)
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
@ -393,12 +394,21 @@ class LLM:
|
||||
conversation, mm_data = parse_chat_messages(messages, model_config,
|
||||
tokenizer)
|
||||
|
||||
prompt = apply_chat_template(
|
||||
tokenizer,
|
||||
conversation,
|
||||
chat_template=chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
)
|
||||
prompt: Union[str, List[int]]
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=messages,
|
||||
chat_template=chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
)
|
||||
else:
|
||||
prompt = apply_hf_chat_template(
|
||||
tokenizer,
|
||||
conversation=conversation,
|
||||
chat_template=chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
)
|
||||
|
||||
inputs: PromptInputs
|
||||
if is_list_of(prompt, int):
|
||||
|
@ -11,7 +11,8 @@ from fastapi import Request
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import AsyncEngineClient
|
||||
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
||||
apply_chat_template,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
load_chat_template,
|
||||
parse_chat_messages_futures)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
@ -35,7 +36,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import iterate_with_cancellation, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -121,15 +122,27 @@ class OpenAIServingChat(OpenAIServing):
|
||||
tool.model_dump() for tool in request.tools
|
||||
]
|
||||
|
||||
prompt = apply_chat_template(
|
||||
tokenizer,
|
||||
conversation=conversation,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
tools=tool_dicts,
|
||||
documents=request.documents,
|
||||
**(request.chat_template_kwargs or {}),
|
||||
)
|
||||
prompt: Union[str, List[int]]
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
tools=tool_dicts,
|
||||
documents=request.documents,
|
||||
**(request.chat_template_kwargs or {}),
|
||||
)
|
||||
else:
|
||||
prompt = apply_hf_chat_template(
|
||||
tokenizer,
|
||||
conversation=conversation,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
tools=tool_dicts,
|
||||
documents=request.documents,
|
||||
**(request.chat_template_kwargs or {}),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error in applying chat template from request: %s", e)
|
||||
return self.create_error_response(str(e))
|
||||
@ -307,11 +320,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# Send response to echo the input portion of the
|
||||
# last message
|
||||
if request.echo:
|
||||
last_msg_content: Optional[str] = ""
|
||||
if conversation and conversation[-1].get(
|
||||
"content") and conversation[-1].get(
|
||||
"role") == role:
|
||||
last_msg_content = conversation[-1]["content"]
|
||||
last_msg_content: str = ""
|
||||
if conversation and "content" in conversation[
|
||||
-1] and conversation[-1].get("role") == role:
|
||||
last_msg_content = conversation[-1]["content"] or ""
|
||||
|
||||
if last_msg_content:
|
||||
for i in range(num_choices):
|
||||
@ -659,8 +671,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
if conversation and conversation[-1].get(
|
||||
"content") and conversation[-1].get("role") == role:
|
||||
if conversation and "content" in conversation[-1] and conversation[
|
||||
-1].get("role") == role:
|
||||
last_msg_content = conversation[-1]["content"] or ""
|
||||
|
||||
for choice in choices:
|
||||
|
@ -2,7 +2,8 @@ from typing import List, Optional, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import AsyncEngineClient
|
||||
from vllm.entrypoints.chat_utils import (apply_chat_template,
|
||||
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
load_chat_template,
|
||||
parse_chat_messages_futures)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
@ -18,6 +19,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
OpenAIServing)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -66,6 +68,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
|
||||
|
||||
prompt: Union[str, List[int]]
|
||||
if isinstance(request, TokenizeChatRequest):
|
||||
model_config = self.model_config
|
||||
|
||||
@ -77,12 +80,20 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
logger.warning(
|
||||
"Multi-modal inputs are ignored during tokenization")
|
||||
|
||||
prompt = apply_chat_template(
|
||||
tokenizer,
|
||||
conversation=conversation,
|
||||
chat_template=self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
)
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=request.messages,
|
||||
chat_template=self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
)
|
||||
else:
|
||||
prompt = apply_hf_chat_template(
|
||||
tokenizer,
|
||||
conversation=conversation,
|
||||
chat_template=self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
)
|
||||
else:
|
||||
prompt = request.prompt
|
||||
|
||||
|
@ -16,7 +16,7 @@ from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
|
||||
Tekkenizer)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.chat_utils import ConversationMessage
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -122,19 +122,19 @@ class MistralTokenizer:
|
||||
return []
|
||||
|
||||
def encode(self, prompt: str) -> List[int]:
|
||||
# `encode ` should only be used for prompt completion
|
||||
# `encode` should only be used for prompt completion
|
||||
# it should never be used for chat_completion.
|
||||
# For chat completion use `apply_chat_template`
|
||||
return self.tokenizer.encode(prompt, bos=True, eos=False)
|
||||
|
||||
def apply_chat_template(self,
|
||||
conversation: List["ConversationMessage"],
|
||||
messages: List["ChatCompletionMessageParam"],
|
||||
tools: Optional[Dict[str, Any]] = None,
|
||||
**kwargs) -> List[int]:
|
||||
assert tools is None, "`tools` are not yet supported."
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
messages=conversation) # type: ignore[type-var]
|
||||
messages=messages) # type: ignore[type-var]
|
||||
encoded = self.mistral.encode_chat_completion(request)
|
||||
|
||||
# encode-decode to get clean prompt
|
||||
|
Loading…
x
Reference in New Issue
Block a user