[Bugfix] Fix chat template loading (#15143)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: chaunceyjiang <chaunceyjiang@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>

parent 038de04d7b
commit cbcdf2c609
@@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     # Call the function and get the result
     result = apply_hf_chat_template(
         tokenizer,
+        trust_remote_code=True,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
+        tools=None,
         add_generation_prompt=mock_request.add_generation_prompt,
         continue_final_message=mock_request.continue_final_message,
     )
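For reference, apply_hf_chat_template now takes the tokenizer as the first
positional argument, makes tools an explicit argument, and accepts a
trust_remote_code keyword. A minimal sketch of the new call shape (model
name and messages are illustrative):

    from transformers import AutoTokenizer

    from vllm.entrypoints.chat_utils import apply_hf_chat_template

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    prompt = apply_hf_chat_template(
        tokenizer,
        trust_remote_code=True,
        conversation=[{"role": "user", "content": "Hello!"}],
        chat_template=None,  # fall back to the processor/tokenizer template
        tools=None,          # now spelled out explicitly
        add_generation_prompt=True,
    )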
@@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=6299, total_tokens=6309)
+        completion_tokens=10, prompt_tokens=6287, total_tokens=6297)

     message = choice.message
     message = chat_completion.choices[0].message
@@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded(
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=6299, total_tokens=6309)
+        completion_tokens=10, prompt_tokens=6287, total_tokens=6297)

     message = choice.message
     message = chat_completion.choices[0].message
@@ -4,10 +4,13 @@ import warnings
 from typing import Optional

 import pytest
+from packaging.version import Version
+from transformers import __version__ as TRANSFORMERS_VERSION

 from vllm.assets.image import ImageAsset
 from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
+from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
+                                         _try_extract_ast, load_chat_template,
                                          parse_chat_messages,
                                          parse_chat_messages_futures,
                                          resolve_chat_template_content_format)
@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples"
 PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
 ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
+QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
 MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
+HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"


 @pytest.fixture(scope="function")
@@ -703,25 +708,27 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):

     vllm_result = apply_hf_chat_template(
         tokenizer,
+        trust_remote_code=model_config.trust_remote_code,
         conversation=conversation,
         chat_template=None,
+        tools=None,
         add_generation_prompt=True,
     )

     assert hf_result == vllm_result


-# yapf: disable
 @pytest.mark.parametrize(
-    ("model", "expected_format"),
-    [(PHI3V_MODEL_ID, "string"),
-     (QWEN2VL_MODEL_ID, "openai"),
-     (ULTRAVOX_MODEL_ID, "string"),
-     (MLLAMA_MODEL_ID, "openai"),
-     (LLAMA_GUARD_MODEL_ID, "openai")],
-)
-# yapf: enable
-def test_resolve_content_format_hf_defined(model, expected_format):
+    "model",
+    [
+        QWEN2VL_MODEL_ID,  # tokenizer.chat_template is of type str
+        HERMES_MODEL_ID,  # tokenizer.chat_template is of type dict
+    ])
+@pytest.mark.parametrize("use_tools", [True, False])
+def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
+    """checks that chat_template is a dict type for HF models."""
+
+    # Build the tokenizer group and grab the underlying tokenizer
     tokenizer_group = TokenizerGroup(
         model,
         enable_lora=False,
@@ -730,7 +737,56 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     )
     tokenizer = tokenizer_group.tokenizer

-    chat_template = tokenizer.chat_template
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "dummy_function_name",
+            "description": "This is a dummy function",
+            "parameters": sample_json_schema
+        }
+    }] if use_tools else None
+
+    # Test detecting the tokenizer's chat_template
+    chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=None,
+        tools=tools,
+        trust_remote_code=True,
+    )
     assert isinstance(chat_template, str)

+
+# yapf: disable
+@pytest.mark.parametrize(
+    ("model", "expected_format"),
+    [(PHI3V_MODEL_ID, "string"),
+     (QWEN2VL_MODEL_ID, "openai"),
+     (QWEN25VL_MODEL_ID, "openai"),
+     (ULTRAVOX_MODEL_ID, "string"),
+     (MLLAMA_MODEL_ID, "openai"),
+     (LLAMA_GUARD_MODEL_ID, "openai")],
+)
+# yapf: enable
+def test_resolve_content_format_hf_defined(model, expected_format):
+    if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
+            "4.49.0"):
+        pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
+
+    tokenizer_group = TokenizerGroup(
+        model,
+        enable_lora=False,
+        max_num_seqs=5,
+        max_input_length=None,
+    )
+    tokenizer = tokenizer_group.tokenizer
+
+    # Test detecting the tokenizer's chat_template
+    chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=None,
+        tools=None,
+        trust_remote_code=True,
+    )
+    assert isinstance(chat_template, str)
+
     print("[TEXT]")
@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):

     resolved_format = resolve_chat_template_content_format(
         None,  # Test detecting the tokenizer's chat_template
+        None,
         "auto",
         tokenizer,
+        trust_remote_code=True,
     )

     assert resolved_format == expected_format
@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):

     resolved_format = resolve_chat_template_content_format(
         chat_template,
+        None,
         "auto",
         dummy_tokenizer,
+        trust_remote_code=True,
     )

     assert resolved_format == expected_format
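Both call sites reflect the new resolver signature: the tool definitions and
a trust_remote_code flag are threaded through, so detection inspects the
same template the request will actually render with. A sketch, assuming a
HF tokenizer object is already in scope:

    from vllm.entrypoints.chat_utils import resolve_chat_template_content_format

    fmt = resolve_chat_template_content_format(
        None,    # chat_template: fall back to the model's own template
        None,    # tools
        "auto",  # requested content format
        tokenizer,
        trust_remote_code=True,
    )
    assert fmt in ("string", "openai")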
@@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]],

 # universal args for all models go here. also good if you need to test locally
 # and change type or KV cache quantization or something.
-ARGS: list[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"]
+ARGS: list[str] = [
+    "--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs",
+    "256"
+]

 CONFIGS: dict[str, ServerConfig] = {
     "hermes": {
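These are ordinary vllm serve flags handed to the tool-use test server. A
sketch of how such args are consumed, assuming the RemoteOpenAIServer test
helper from tests/utils.py (model name is illustrative):

    from tests.utils import RemoteOpenAIServer

    ARGS = ["--enable-auto-tool-choice", "--max-model-len", "1024",
            "--max-num-seqs", "256"]

    # roughly equivalent to: vllm serve <model> --enable-auto-tool-choice \
    #     --max-model-len 1024 --max-num-seqs 256
    with RemoteOpenAIServer("NousResearch/Hermes-3-Llama-3.1-8B", ARGS) as server:
        client = server.get_client()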
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

 import asyncio
-import codecs
 import json
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
@@ -30,7 +29,8 @@ from openai.types.chat.chat_completion_content_part_input_audio_param import (
     InputAudio)
 # yapf: enable
 # pydantic needs the TypedDict from typing_extensions
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
+                          ProcessorMixin)
 from typing_extensions import Required, TypeAlias, TypedDict

 from vllm.config import ModelConfig
@@ -306,24 +306,63 @@ def _detect_content_format(
         return "openai"


+def _resolve_hf_chat_template(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    chat_template: Optional[str],
+    tools: Optional[list[dict[str, Any]]],
+    *,
+    trust_remote_code: bool,
+) -> Optional[str]:
+    # 1st priority: The given chat template
+    if chat_template is not None:
+        return chat_template
+
+    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
+    if tools is None:
+        try:
+            processor = cached_get_processor(
+                tokenizer.name_or_path,
+                processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
+                               ProcessorMixin),
+                trust_remote_code=trust_remote_code,
+            )
+            if isinstance(processor, ProcessorMixin) and \
+                    processor.chat_template is not None:
+                return processor.chat_template
+        except Exception:
+            logger.debug("Failed to load AutoProcessor chat template for %s",
+                         tokenizer.name_or_path, exc_info=True)
+
+    # 3rd priority: AutoTokenizer chat template
+    try:
+        return tokenizer.get_chat_template(chat_template, tools=tools)
+    except Exception:
+        logger.debug("Failed to load AutoTokenizer chat template for %s",
+                     tokenizer.name_or_path, exc_info=True)
+
+    return None
+
+
 def _resolve_chat_template_content_format(
     chat_template: Optional[str],
+    tools: Optional[list[dict[str, Any]]],
     given_format: ChatTemplateContentFormatOption,
     tokenizer: AnyTokenizer,
+    *,
+    trust_remote_code: bool,
 ) -> _ChatTemplateContentFormat:
     if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
-        tokenizer_chat_template = tokenizer.chat_template
+        hf_chat_template = _resolve_hf_chat_template(
+            tokenizer,
+            chat_template=chat_template,
+            trust_remote_code=trust_remote_code,
+            tools=tools,
+        )
     else:
-        tokenizer_chat_template = None
+        hf_chat_template = None

-    jinja_text: Optional[str]
-    if isinstance(tokenizer_chat_template, str) and chat_template is None:
-        jinja_text = tokenizer_chat_template
-    elif (isinstance(tokenizer_chat_template, dict)
-          and chat_template in tokenizer_chat_template):
-        jinja_text = tokenizer_chat_template[chat_template]
-    else:
-        jinja_text = load_chat_template(chat_template, is_literal=True)
+    jinja_text = (hf_chat_template if isinstance(hf_chat_template, str)
+                  else load_chat_template(chat_template, is_literal=True))

     detected_format = ("string" if jinja_text is None else
                        _detect_content_format(jinja_text, default="string"))
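The new helper makes the lookup order explicit: an explicitly supplied
template always wins; otherwise, when no tools are given, a multimodal
processor's template is preferred; finally the tokenizer's own template is
consulted, with None meaning nothing was found. A sketch, assuming a HF
tokenizer object is in scope (note the helper is private API):

    from vllm.entrypoints.chat_utils import _resolve_hf_chat_template

    # 1st priority: an explicit template is returned as-is
    explicit = _resolve_hf_chat_template(
        tokenizer,
        chat_template="{{ messages }}",
        tools=None,
        trust_remote_code=False,
    )
    assert explicit == "{{ messages }}"

    # 2nd/3rd priority: AutoProcessor template (only when tools is None),
    # then the AutoTokenizer template; None if neither defines one
    detected = _resolve_hf_chat_template(
        tokenizer,
        chat_template=None,
        tools=None,
        trust_remote_code=False,
    )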
@@ -332,17 +371,11 @@ def _resolve_chat_template_content_format(


 @lru_cache
-def resolve_chat_template_content_format(
+def _log_chat_template_content_format(
     chat_template: Optional[str],
     given_format: ChatTemplateContentFormatOption,
-    tokenizer: AnyTokenizer,
-) -> _ChatTemplateContentFormat:
-    detected_format = _resolve_chat_template_content_format(
-        chat_template,
-        given_format,
-        tokenizer,
-    )
-
+    detected_format: ChatTemplateContentFormatOption,
+):
     logger.info(
         "Detected the chat template content format to be '%s'. "
         "You can set `--chat-template-content-format` to override this.",
@@ -360,6 +393,29 @@ def resolve_chat_template_content_format(
         detected_format,
     )

+
+def resolve_chat_template_content_format(
+    chat_template: Optional[str],
+    tools: Optional[list[dict[str, Any]]],
+    given_format: ChatTemplateContentFormatOption,
+    tokenizer: AnyTokenizer,
+    *,
+    trust_remote_code: bool = False,
+) -> _ChatTemplateContentFormat:
+    detected_format = _resolve_chat_template_content_format(
+        chat_template,
+        tools,
+        given_format,
+        tokenizer,
+        trust_remote_code=trust_remote_code,
+    )
+
+    _log_chat_template_content_format(
+        chat_template,
+        given_format=given_format,
+        detected_format=detected_format,
+    )
+
     return detected_format
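Splitting detection from logging lets the lru_cache sit on the logging
helper alone, so the "Detected the chat template content format" message is
emitted once per unique combination, while detection itself runs on every
call (it can now depend on the request's tools). The pattern in miniature:

    from functools import lru_cache

    @lru_cache
    def _log_once(detected: str) -> None:
        # side effect fires only once per distinct value
        print(f"Detected the chat template content format to be '{detected}'")

    def resolve(detected: str) -> str:
        _log_once(detected)
        return detected  # the result is still computed on every call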
@@ -711,7 +767,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
         f"{type(chat_template)} is not a valid chat template type")


-def load_chat_template(
+def _load_chat_template(
     chat_template: Optional[Union[Path, str]],
     *,
     is_literal: bool = False,
@@ -724,7 +780,7 @@ def load_chat_template(
             raise TypeError("chat_template is expected to be read directly "
                             "from its value")

-        return codecs.decode(chat_template, "unicode_escape")
+        return chat_template

     try:
         with open(chat_template) as f:
@@ -742,7 +798,18 @@

     # If opening a file fails, set chat template to be args to
     # ensure we decode so our escape are interpreted correctly
-    return load_chat_template(chat_template, is_literal=True)
+    return _load_chat_template(chat_template, is_literal=True)


+_cached_load_chat_template = lru_cache(_load_chat_template)
+
+
+def load_chat_template(
+    chat_template: Optional[Union[Path, str]],
+    *,
+    is_literal: bool = False,
+) -> Optional[str]:
+    return _cached_load_chat_template(chat_template, is_literal=is_literal)
+
+
 # TODO: Let user specify how to insert multimodal tokens into prompt
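Caching moves from the resolver to the loader: lru_cache wraps
_load_chat_template so repeated requests do not re-read the template file.
This requires hashable arguments, which str, Path, and None all are. The
idea in miniature:

    from functools import lru_cache

    def _read_template(path: str) -> str:
        with open(path) as f:  # hits the filesystem only on a cache miss
            return f.read()

    read_template = lru_cache(_read_template)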
@@ -1067,23 +1134,20 @@ def apply_hf_chat_template(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     conversation: list[ConversationMessage],
     chat_template: Optional[str],
+    tools: Optional[list[dict[str, Any]]],
     *,
+    trust_remote_code: bool = False,
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
 ) -> str:
-    if chat_template is None:
-        chat_template = tokenizer.chat_template
+    hf_chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=chat_template,
+        tools=tools,
+        trust_remote_code=trust_remote_code,
+    )

-    # FIXME: Temporary workaround for
-    # https://huggingface.co/mistral-community/pixtral-12b/discussions/31
-    if chat_template is None:
-        try:
-            processor = cached_get_processor(tokenizer.name_or_path)
-            chat_template = processor.chat_template
-        except Exception:
-            pass
-
-    if chat_template is None:
+    if hf_chat_template is None:
         raise ValueError(
             "As of transformers v4.44, default chat template is no longer "
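The FIXME workaround for the Pixtral processor template is subsumed by
_resolve_hf_chat_template, and a missing template now fails fast instead of
silently rendering nothing. A sketch of the failure mode (names are
illustrative, with a tokenizer that defines no template):

    try:
        apply_hf_chat_template(
            tokenizer,
            conversation=[{"role": "user", "content": "Hi"}],
            chat_template=None,
            tools=None,
        )
    except ValueError as e:
        # "As of transformers v4.44, default chat template is no longer ..."
        print(e)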
@@ -1091,7 +1155,8 @@ def apply_hf_chat_template(

     return tokenizer.apply_chat_template(
         conversation=conversation,  # type: ignore[arg-type]
-        chat_template=chat_template,
+        tools=tools,  # type: ignore[arg-type]
+        chat_template=hf_chat_template,
         tokenize=tokenize,
         **kwargs,
     )
@@ -1100,7 +1165,8 @@
 def apply_mistral_chat_template(
     tokenizer: MistralTokenizer,
     messages: list[ChatCompletionMessageParam],
-    chat_template: Optional[str] = None,
+    chat_template: Optional[str],
+    tools: Optional[list[dict[str, Any]]],
     **kwargs: Any,
 ) -> list[int]:
     if chat_template is not None:
@@ -1117,5 +1183,6 @@ def apply_mistral_chat_template(

     return tokenizer.apply_chat_template(
         messages=messages,
+        tools=tools,
         **kwargs,
     )
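The Mistral path now mirrors the HF path: chat_template no longer defaults
to None and tools must be passed explicitly, so every call site names both.
A sketch, assuming a MistralTokenizer object is in scope:

    token_ids = apply_mistral_chat_template(
        mistral_tokenizer,
        messages=[{"role": "user", "content": "Hello!"}],
        chat_template=None,  # must now be spelled out
        tools=None,
    )  # -> list[int] of prompt token ids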
@@ -690,8 +690,10 @@ class LLM:
         model_config = self.llm_engine.get_model_config()
         resolved_content_format = resolve_chat_template_content_format(
             chat_template,
+            tools,
             chat_template_content_format,
             tokenizer,
+            trust_remote_code=model_config.trust_remote_code,
         )

         prompts: list[Union[TokensPrompt, TextPrompt]] = []
@@ -713,18 +715,19 @@ class LLM:
                     tokenizer,
                     messages=msgs,
                     chat_template=chat_template,
+                    tools=tools,
                     add_generation_prompt=add_generation_prompt,
                     continue_final_message=continue_final_message,
-                    tools=tools,
                 )
             else:
                 prompt_data = apply_hf_chat_template(
                     tokenizer,
+                    trust_remote_code=model_config.trust_remote_code,
                     conversation=conversation,
                     chat_template=chat_template,
+                    tools=tools,
                     add_generation_prompt=add_generation_prompt,
                     continue_final_message=continue_final_message,
-                    tools=tools,
                 )

             prompt: Union[TokensPrompt, TextPrompt]
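End to end, the tool definitions passed to LLM.chat now reach both
content-format detection and template rendering. A sketch (model and schema
are illustrative):

    from vllm import LLM

    llm = LLM(model="NousResearch/Hermes-3-Llama-3.1-8B")
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]
    outputs = llm.chat(
        [{"role": "user", "content": "What is the weather in Paris?"}],
        tools=tools,
    )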
@@ -379,14 +379,18 @@ class OpenAIServing:
         add_special_tokens: bool = False,
     ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
                list[TokensPrompt]]:
+        model_config = self.model_config
+
         resolved_content_format = resolve_chat_template_content_format(
             chat_template,
+            tool_dicts,
             chat_template_content_format,
             tokenizer,
+            trust_remote_code=model_config.trust_remote_code,
         )
         conversation, mm_data_future = parse_chat_messages_futures(
             messages,
-            self.model_config,
+            model_config,
             tokenizer,
             content_format=resolved_content_format,
         )
@@ -410,6 +414,7 @@ class OpenAIServing:
         else:
             request_prompt = apply_hf_chat_template(
                 tokenizer,
+                trust_remote_code=model_config.trust_remote_code,
                 conversation=conversation,
                 **_chat_template_kwargs,
             )
|
Loading…
x
Reference in New Issue
Block a user