[Bugfix] Fix chat template loading (#15143)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: chaunceyjiang <chaunceyjiang@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Cyrus Leung 2025-03-24 21:50:09 +08:00 committed by GitHub
parent 038de04d7b
commit cbcdf2c609
7 changed files with 196 additions and 56 deletions

View File

@@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Call the function and get the result
result = apply_hf_chat_template(
tokenizer,
trust_remote_code=True,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,
tools=None,
add_generation_prompt=mock_request.add_generation_prompt,
continue_final_message=mock_request.continue_final_message,
)
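This hunk shows the new calling convention: apply_hf_chat_template now takes tools explicitly and receives trust_remote_code at every call site. A minimal standalone sketch of the same call, assuming a Hugging Face instruct model (the model name below is only an example, not taken from this change):

    from transformers import AutoTokenizer

    from vllm.entrypoints.chat_utils import apply_hf_chat_template

    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-3-Llama-3.1-8B")

    # chat_template=None falls back to the processor/tokenizer templates;
    # tools=None means no tool schemas are injected into the prompt.
    prompt = apply_hf_chat_template(
        tokenizer,
        trust_remote_code=False,
        conversation=[{"role": "user", "content": "Hello!"}],
        chat_template=None,
        tools=None,
        add_generation_prompt=True,
        continue_final_message=False,
    )
    print(prompt)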

View File

@@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=6299, total_tokens=6309)
completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
message = choice.message
message = chat_completion.choices[0].message
@@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded(
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=6299, total_tokens=6309)
completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
message = choice.message
message = chat_completion.choices[0].message

View File

@@ -4,10 +4,13 @@ import warnings
from typing import Optional
import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
_try_extract_ast, load_chat_template,
parse_chat_messages,
parse_chat_messages_futures,
resolve_chat_template_content_format)
@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples"
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
@pytest.fixture(scope="function")
@@ -703,25 +708,27 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
vllm_result = apply_hf_chat_template(
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
chat_template=None,
tools=None,
add_generation_prompt=True,
)
assert hf_result == vllm_result
# yapf: disable
@pytest.mark.parametrize(
("model", "expected_format"),
[(PHI3V_MODEL_ID, "string"),
(QWEN2VL_MODEL_ID, "openai"),
(ULTRAVOX_MODEL_ID, "string"),
(MLLAMA_MODEL_ID, "openai"),
(LLAMA_GUARD_MODEL_ID, "openai")],
)
# yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format):
"model",
[
QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str
HERMES_MODEL_ID, # tokenizer.chat_template is of type dict
])
@pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
"""checks that chat_template is a dict type for HF models."""
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
model,
enable_lora=False,
@@ -730,7 +737,56 @@ def test_resolve_content_format_hf_defined(model, expected_format):
)
tokenizer = tokenizer_group.tokenizer
chat_template = tokenizer.chat_template
tools = [{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema
}
}] if use_tools else None
# Test detecting the tokenizer's chat_template
chat_template = _resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
trust_remote_code=True,
)
assert isinstance(chat_template, str)
# yapf: disable
@pytest.mark.parametrize(
("model", "expected_format"),
[(PHI3V_MODEL_ID, "string"),
(QWEN2VL_MODEL_ID, "openai"),
(QWEN25VL_MODEL_ID, "openai"),
(ULTRAVOX_MODEL_ID, "string"),
(MLLAMA_MODEL_ID, "openai"),
(LLAMA_GUARD_MODEL_ID, "openai")],
)
# yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format):
if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
"4.49.0"):
pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
tokenizer_group = TokenizerGroup(
model,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template
chat_template = _resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=True,
)
assert isinstance(chat_template, str)
print("[TEXT]")
@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
resolved_format = resolve_chat_template_content_format(
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
trust_remote_code=True,
)
assert resolved_format == expected_format
@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
resolved_format = resolve_chat_template_content_format(
chat_template,
None,
"auto",
dummy_tokenizer,
trust_remote_code=True,
)
assert resolved_format == expected_format

View File

@@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]],
# Universal args for all models go here. Also useful if you need to test
# locally and change the model type, KV cache quantization, etc.
ARGS: list[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"]
ARGS: list[str] = [
"--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs",
"256"
]
CONFIGS: dict[str, ServerConfig] = {
"hermes": {

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict, deque
@@ -30,7 +29,8 @@ from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin)
from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
@@ -306,24 +306,63 @@ def _detect_content_format(
return "openai"
def _resolve_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool,
) -> Optional[str]:
# 1st priority: The given chat template
if chat_template is not None:
return chat_template
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if tools is None:
try:
processor = cached_get_processor(
tokenizer.name_or_path,
processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin),
trust_remote_code=trust_remote_code,
)
if isinstance(processor, ProcessorMixin) and \
processor.chat_template is not None:
return processor.chat_template
except Exception:
logger.debug("Failed to load AutoProcessor chat template for %s",
tokenizer.name_or_path, exc_info=True)
# 3rd priority: AutoTokenizer chat template
try:
return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception:
logger.debug("Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path, exc_info=True)
return None
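In short, the resolver tries the explicitly given template first, then the AutoProcessor template (skipped when tool schemas are supplied), and finally the tokenizer's own template. A small sketch of exercising this directly, using a model from the tests above; treat it as illustrative, not part of the change:

    from transformers import AutoTokenizer

    from vllm.entrypoints.chat_utils import _resolve_hf_chat_template

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    # No override and no tools: the processor template (if present) wins,
    # otherwise the tokenizer's chat_template is returned; None means neither
    # source defines a template.
    template = _resolve_hf_chat_template(
        tokenizer,
        chat_template=None,
        tools=None,
        trust_remote_code=False,
    )
    assert template is None or isinstance(template, str)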
def _resolve_chat_template_content_format(
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
tokenizer_chat_template = tokenizer.chat_template
hf_chat_template = _resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
trust_remote_code=trust_remote_code,
tools=tools,
)
else:
tokenizer_chat_template = None
hf_chat_template = None
jinja_text: Optional[str]
if isinstance(tokenizer_chat_template, str) and chat_template is None:
jinja_text = tokenizer_chat_template
elif (isinstance(tokenizer_chat_template, dict)
and chat_template in tokenizer_chat_template):
jinja_text = tokenizer_chat_template[chat_template]
else:
jinja_text = load_chat_template(chat_template, is_literal=True)
jinja_text = (hf_chat_template if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True))
detected_format = ("string" if jinja_text is None else
_detect_content_format(jinja_text, default="string"))
@@ -332,17 +371,11 @@ def _resolve_chat_template_content_format(
@lru_cache
def resolve_chat_template_content_format(
def _log_chat_template_content_format(
chat_template: Optional[str],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
) -> _ChatTemplateContentFormat:
detected_format = _resolve_chat_template_content_format(
chat_template,
given_format,
tokenizer,
)
detected_format: ChatTemplateContentFormatOption,
):
logger.info(
"Detected the chat template content format to be '%s'. "
"You can set `--chat-template-content-format` to override this.",
@@ -360,6 +393,29 @@ def resolve_chat_template_content_format(
detected_format,
)
def resolve_chat_template_content_format(
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool = False,
) -> _ChatTemplateContentFormat:
detected_format = _resolve_chat_template_content_format(
chat_template,
tools,
given_format,
tokenizer,
trust_remote_code=trust_remote_code,
)
_log_chat_template_content_format(
chat_template,
given_format=given_format,
detected_format=detected_format,
)
return detected_format
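The new public entry point takes the tool list and trust_remote_code alongside the template and requested format, then returns either "string" or "openai". A minimal sketch of calling it directly, mirroring the tests above (the model name is just an example):

    from transformers import AutoTokenizer

    from vllm.entrypoints.chat_utils import resolve_chat_template_content_format

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    content_format = resolve_chat_template_content_format(
        None,    # no override template: detect from the model's own template
        None,    # no tool schemas
        "auto",  # let vLLM decide between "string" and "openai"
        tokenizer,
        trust_remote_code=False,
    )
    assert content_format in ("string", "openai")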
@@ -711,7 +767,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
f"{type(chat_template)} is not a valid chat template type")
def load_chat_template(
def _load_chat_template(
chat_template: Optional[Union[Path, str]],
*,
is_literal: bool = False,
@@ -724,7 +780,7 @@ def load_chat_template(
raise TypeError("chat_template is expected to be read directly "
"from its value")
return codecs.decode(chat_template, "unicode_escape")
return chat_template
try:
with open(chat_template) as f:
@@ -742,7 +798,18 @@ def load_chat_template(
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
return load_chat_template(chat_template, is_literal=True)
return _load_chat_template(chat_template, is_literal=True)
_cached_load_chat_template = lru_cache(_load_chat_template)
def load_chat_template(
chat_template: Optional[Union[Path, str]],
*,
is_literal: bool = False,
) -> Optional[str]:
return _cached_load_chat_template(chat_template, is_literal=is_literal)
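The loader logic itself is unchanged; load_chat_template simply becomes a thin wrapper over an lru_cache-wrapped private helper, so repeated lookups of the same path or literal are not re-read from disk. A generic sketch of the same pattern with hypothetical names, showing why the cached callable is kept separate from the public function:

    from functools import lru_cache
    from pathlib import Path
    from typing import Optional, Union


    def _read_template(source: Optional[Union[Path, str]],
                       *,
                       is_literal: bool = False) -> Optional[str]:
        # Hypothetical stand-in for _load_chat_template: pass literals
        # through, otherwise read the template file from disk.
        if source is None:
            return None
        if is_literal:
            return str(source)
        return Path(source).read_text()


    _cached_read_template = lru_cache(_read_template)


    def read_template(source: Optional[Union[Path, str]],
                      *,
                      is_literal: bool = False) -> Optional[str]:
        # The public wrapper keeps the original signature; tests can still
        # call _cached_read_template.cache_clear() to reset the cache.
        return _cached_read_template(source, is_literal=is_literal)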
# TODO: Let user specify how to insert multimodal tokens into prompt
@@ -1067,23 +1134,20 @@ def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool = False,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
if chat_template is None:
chat_template = tokenizer.chat_template
hf_chat_template = _resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
trust_remote_code=trust_remote_code,
)
# FIXME: Temporary workaround for
# https://huggingface.co/mistral-community/pixtral-12b/discussions/31
if chat_template is None:
try:
processor = cached_get_processor(tokenizer.name_or_path)
chat_template = processor.chat_template
except Exception:
pass
if chat_template is None:
if hf_chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
@@ -1091,7 +1155,8 @@ def apply_hf_chat_template(
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
chat_template=chat_template,
tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=tokenize,
**kwargs,
)
@@ -1100,7 +1165,8 @@ def apply_hf_chat_template(
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam],
chat_template: Optional[str] = None,
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
**kwargs: Any,
) -> list[int]:
if chat_template is not None:
@@ -1117,5 +1183,6 @@ def apply_mistral_chat_template(
return tokenizer.apply_chat_template(
messages=messages,
tools=tools,
**kwargs,
)

View File

@@ -690,8 +690,10 @@ class LLM:
model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
chat_template,
tools,
chat_template_content_format,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
prompts: list[Union[TokensPrompt, TextPrompt]] = []
@@ -713,18 +715,19 @@ class LLM:
tokenizer,
messages=msgs,
chat_template=chat_template,
tools=tools,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
)
else:
prompt_data = apply_hf_chat_template(
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
chat_template=chat_template,
tools=tools,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
)
prompt: Union[TokensPrompt, TextPrompt]
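With LLM.chat now forwarding tools into both the content-format resolution and the template application, an end-to-end sketch of the offline path looks roughly like this (model name and tool schema are illustrative, not part of this diff):

    from vllm import LLM

    llm = LLM(model="NousResearch/Hermes-3-Llama-3.1-8B", max_model_len=4096)

    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]

    # tools and chat_template flow into resolve_chat_template_content_format
    # and apply_hf_chat_template as shown in the hunks above.
    outputs = llm.chat(
        [{"role": "user", "content": "What's the weather in Paris?"}],
        tools=tools,
        chat_template=None,
    )
    print(outputs[0].outputs[0].text)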

View File

@@ -379,14 +379,18 @@ class OpenAIServing:
add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
list[TokensPrompt]]:
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
chat_template,
tool_dicts,
chat_template_content_format,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
conversation, mm_data_future = parse_chat_messages_futures(
messages,
self.model_config,
model_config,
tokenizer,
content_format=resolved_content_format,
)
@@ -410,6 +414,7 @@ class OpenAIServing:
else:
request_prompt = apply_hf_chat_template(
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
**_chat_template_kwargs,
)