[Bugfix] Fix chat template loading (#15143)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: chaunceyjiang <chaunceyjiang@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
commit cbcdf2c609
parent 038de04d7b
@@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     # Call the function and get the result
     result = apply_hf_chat_template(
         tokenizer,
+        trust_remote_code=True,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
+        tools=None,
         add_generation_prompt=mock_request.add_generation_prompt,
         continue_final_message=mock_request.continue_final_message,
     )
@@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=6299, total_tokens=6309)
+        completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
 
     message = choice.message
     message = chat_completion.choices[0].message
@@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded(
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=6299, total_tokens=6309)
+        completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
 
     message = choice.message
     message = chat_completion.choices[0].message
@@ -4,10 +4,13 @@ import warnings
 from typing import Optional
 
 import pytest
+from packaging.version import Version
+from transformers import __version__ as TRANSFORMERS_VERSION
 
 from vllm.assets.image import ImageAsset
 from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
+from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
+                                         _try_extract_ast, load_chat_template,
                                          parse_chat_messages,
                                          parse_chat_messages_futures,
                                          resolve_chat_template_content_format)
@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples"
 PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
 ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
+QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
 MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
+HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
 
 
 @pytest.fixture(scope="function")
@@ -703,25 +708,27 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
 
     vllm_result = apply_hf_chat_template(
         tokenizer,
+        trust_remote_code=model_config.trust_remote_code,
         conversation=conversation,
         chat_template=None,
+        tools=None,
         add_generation_prompt=True,
     )
 
     assert hf_result == vllm_result
 
 
-# yapf: disable
 @pytest.mark.parametrize(
-    ("model", "expected_format"),
-    [(PHI3V_MODEL_ID, "string"),
-     (QWEN2VL_MODEL_ID, "openai"),
-     (ULTRAVOX_MODEL_ID, "string"),
-     (MLLAMA_MODEL_ID, "openai"),
-     (LLAMA_GUARD_MODEL_ID, "openai")],
-)
-# yapf: enable
-def test_resolve_content_format_hf_defined(model, expected_format):
+    "model",
+    [
+        QWEN2VL_MODEL_ID,  # tokenizer.chat_template is of type str
+        HERMES_MODEL_ID,  # tokenizer.chat_template is of type dict
+    ])
+@pytest.mark.parametrize("use_tools", [True, False])
+def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
+    """checks that chat_template is a dict type for HF models."""
+
+    # Build the tokenizer group and grab the underlying tokenizer
     tokenizer_group = TokenizerGroup(
         model,
         enable_lora=False,
@@ -730,7 +737,56 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     )
     tokenizer = tokenizer_group.tokenizer
 
-    chat_template = tokenizer.chat_template
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "dummy_function_name",
+            "description": "This is a dummy function",
+            "parameters": sample_json_schema
+        }
+    }] if use_tools else None
+
+    # Test detecting the tokenizer's chat_template
+    chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=None,
+        tools=tools,
+        trust_remote_code=True,
+    )
+    assert isinstance(chat_template, str)
+
+
+# yapf: disable
+@pytest.mark.parametrize(
+    ("model", "expected_format"),
+    [(PHI3V_MODEL_ID, "string"),
+     (QWEN2VL_MODEL_ID, "openai"),
+     (QWEN25VL_MODEL_ID, "openai"),
+     (ULTRAVOX_MODEL_ID, "string"),
+     (MLLAMA_MODEL_ID, "openai"),
+     (LLAMA_GUARD_MODEL_ID, "openai")],
+)
+# yapf: enable
+def test_resolve_content_format_hf_defined(model, expected_format):
+    if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
+            "4.49.0"):
+        pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
+
+    tokenizer_group = TokenizerGroup(
+        model,
+        enable_lora=False,
+        max_num_seqs=5,
+        max_input_length=None,
+    )
+    tokenizer = tokenizer_group.tokenizer
+
+    # Test detecting the tokenizer's chat_template
+    chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=None,
+        tools=None,
+        trust_remote_code=True,
+    )
     assert isinstance(chat_template, str)
 
     print("[TEXT]")
@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
 
     resolved_format = resolve_chat_template_content_format(
         None,  # Test detecting the tokenizer's chat_template
+        None,
         "auto",
         tokenizer,
+        trust_remote_code=True,
     )
 
     assert resolved_format == expected_format
@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
 
     resolved_format = resolve_chat_template_content_format(
         chat_template,
+        None,
        "auto",
         dummy_tokenizer,
+        trust_remote_code=True,
     )
 
     assert resolved_format == expected_format
@@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]],
 
 # universal args for all models go here. also good if you need to test locally
 # and change type or KV cache quantization or something.
-ARGS: list[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"]
+ARGS: list[str] = [
+    "--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs",
+    "256"
+]
 
 CONFIGS: dict[str, ServerConfig] = {
     "hermes": {
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import asyncio
-import codecs
 import json
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
@@ -30,7 +29,8 @@ from openai.types.chat.chat_completion_content_part_input_audio_param import (
     InputAudio)
 # yapf: enable
 # pydantic needs the TypedDict from typing_extensions
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
+                          ProcessorMixin)
 from typing_extensions import Required, TypeAlias, TypedDict
 
 from vllm.config import ModelConfig
@@ -306,24 +306,63 @@ def _detect_content_format(
         return "openai"
 
 
+def _resolve_hf_chat_template(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    chat_template: Optional[str],
+    tools: Optional[list[dict[str, Any]]],
+    *,
+    trust_remote_code: bool,
+) -> Optional[str]:
+    # 1st priority: The given chat template
+    if chat_template is not None:
+        return chat_template
+
+    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
+    if tools is None:
+        try:
+            processor = cached_get_processor(
+                tokenizer.name_or_path,
+                processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
+                               ProcessorMixin),
+                trust_remote_code=trust_remote_code,
+            )
+            if isinstance(processor, ProcessorMixin) and \
+                    processor.chat_template is not None:
+                return processor.chat_template
+        except Exception:
+            logger.debug("Failed to load AutoProcessor chat template for %s",
+                         tokenizer.name_or_path, exc_info=True)
+
+    # 3rd priority: AutoTokenizer chat template
+    try:
+        return tokenizer.get_chat_template(chat_template, tools=tools)
+    except Exception:
+        logger.debug("Failed to load AutoTokenizer chat template for %s",
+                     tokenizer.name_or_path, exc_info=True)
+
+    return None
+
+
 def _resolve_chat_template_content_format(
     chat_template: Optional[str],
+    tools: Optional[list[dict[str, Any]]],
     given_format: ChatTemplateContentFormatOption,
     tokenizer: AnyTokenizer,
+    *,
+    trust_remote_code: bool,
 ) -> _ChatTemplateContentFormat:
     if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
-        tokenizer_chat_template = tokenizer.chat_template
+        hf_chat_template = _resolve_hf_chat_template(
+            tokenizer,
+            chat_template=chat_template,
+            trust_remote_code=trust_remote_code,
+            tools=tools,
+        )
     else:
-        tokenizer_chat_template = None
+        hf_chat_template = None
 
-    jinja_text: Optional[str]
-    if isinstance(tokenizer_chat_template, str) and chat_template is None:
-        jinja_text = tokenizer_chat_template
-    elif (isinstance(tokenizer_chat_template, dict)
-          and chat_template in tokenizer_chat_template):
-        jinja_text = tokenizer_chat_template[chat_template]
-    else:
-        jinja_text = load_chat_template(chat_template, is_literal=True)
+    jinja_text = (hf_chat_template if isinstance(hf_chat_template, str)
                  else load_chat_template(chat_template, is_literal=True))
 
     detected_format = ("string" if jinja_text is None else
                        _detect_content_format(jinja_text, default="string"))
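Not part of the diff: a minimal usage sketch of the `_resolve_hf_chat_template` helper introduced in the hunk above, restating the resolution order it implements (an explicitly supplied template, then the AutoProcessor template when no tools are passed, then the AutoTokenizer template). The model name is only illustrative, and the helper is private to `vllm.entrypoints.chat_utils`, so this is a hedged example of how the surrounding code exercises it rather than a supported API.

from transformers import AutoTokenizer

from vllm.entrypoints.chat_utils import _resolve_hf_chat_template

# Illustrative model; any HF checkpoint with a chat template behaves the same.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Resolution order implemented above:
#   1. an explicitly supplied chat_template string wins,
#   2. otherwise the AutoProcessor template (skipped when tools are given),
#   3. otherwise the AutoTokenizer template; None if nothing is found.
template = _resolve_hf_chat_template(
    tokenizer,
    chat_template=None,   # let the helper discover one
    tools=None,           # passing tool schemas disables the processor path
    trust_remote_code=False,
)
assert template is None or isinstance(template, str)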
@@ -332,17 +371,11 @@ def _resolve_chat_template_content_format(
 
 
 @lru_cache
-def resolve_chat_template_content_format(
+def _log_chat_template_content_format(
     chat_template: Optional[str],
     given_format: ChatTemplateContentFormatOption,
-    tokenizer: AnyTokenizer,
-) -> _ChatTemplateContentFormat:
-    detected_format = _resolve_chat_template_content_format(
-        chat_template,
-        given_format,
-        tokenizer,
-    )
-
+    detected_format: ChatTemplateContentFormatOption,
+):
     logger.info(
         "Detected the chat template content format to be '%s'. "
         "You can set `--chat-template-content-format` to override this.",
@@ -360,6 +393,29 @@ def resolve_chat_template_content_format(
         detected_format,
     )
 
+
+def resolve_chat_template_content_format(
+    chat_template: Optional[str],
+    tools: Optional[list[dict[str, Any]]],
+    given_format: ChatTemplateContentFormatOption,
+    tokenizer: AnyTokenizer,
+    *,
+    trust_remote_code: bool = False,
+) -> _ChatTemplateContentFormat:
+    detected_format = _resolve_chat_template_content_format(
+        chat_template,
+        tools,
+        given_format,
+        tokenizer,
+        trust_remote_code=trust_remote_code,
+    )
+
+    _log_chat_template_content_format(
+        chat_template,
+        given_format=given_format,
+        detected_format=detected_format,
+    )
+
     return detected_format
 
 
@@ -711,7 +767,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
            f"{type(chat_template)} is not a valid chat template type")
 
 
-def load_chat_template(
+def _load_chat_template(
     chat_template: Optional[Union[Path, str]],
     *,
     is_literal: bool = False,
@@ -724,7 +780,7 @@ def load_chat_template(
             raise TypeError("chat_template is expected to be read directly "
                             "from its value")
 
-        return codecs.decode(chat_template, "unicode_escape")
+        return chat_template
 
     try:
         with open(chat_template) as f:
@@ -742,7 +798,18 @@ def load_chat_template(
 
         # If opening a file fails, set chat template to be args to
         # ensure we decode so our escape are interpreted correctly
-        return load_chat_template(chat_template, is_literal=True)
+        return _load_chat_template(chat_template, is_literal=True)
 
 
+_cached_load_chat_template = lru_cache(_load_chat_template)
+
+
+def load_chat_template(
+    chat_template: Optional[Union[Path, str]],
+    *,
+    is_literal: bool = False,
+) -> Optional[str]:
+    return _cached_load_chat_template(chat_template, is_literal=is_literal)
+
+
 # TODO: Let user specify how to insert multimodal tokens into prompt
@@ -1067,23 +1134,20 @@ def apply_hf_chat_template(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     conversation: list[ConversationMessage],
     chat_template: Optional[str],
+    tools: Optional[list[dict[str, Any]]],
     *,
+    trust_remote_code: bool = False,
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
 ) -> str:
-    if chat_template is None:
-        chat_template = tokenizer.chat_template
+    hf_chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=chat_template,
+        tools=tools,
+        trust_remote_code=trust_remote_code,
+    )
 
-    # FIXME: Temporary workaround for
-    # https://huggingface.co/mistral-community/pixtral-12b/discussions/31
-    if chat_template is None:
-        try:
-            processor = cached_get_processor(tokenizer.name_or_path)
-            chat_template = processor.chat_template
-        except Exception:
-            pass
-
-    if chat_template is None:
+    if hf_chat_template is None:
         raise ValueError(
             "As of transformers v4.44, default chat template is no longer "
             "allowed, so you must provide a chat template if the tokenizer "
@@ -1091,7 +1155,8 @@ def apply_hf_chat_template(
 
     return tokenizer.apply_chat_template(
         conversation=conversation,  # type: ignore[arg-type]
-        chat_template=chat_template,
+        tools=tools,  # type: ignore[arg-type]
+        chat_template=hf_chat_template,
         tokenize=tokenize,
         **kwargs,
     )
@@ -1100,7 +1165,8 @@ def apply_hf_chat_template(
 def apply_mistral_chat_template(
     tokenizer: MistralTokenizer,
     messages: list[ChatCompletionMessageParam],
-    chat_template: Optional[str] = None,
+    chat_template: Optional[str],
+    tools: Optional[list[dict[str, Any]]],
     **kwargs: Any,
 ) -> list[int]:
     if chat_template is not None:
@@ -1117,5 +1183,6 @@ def apply_mistral_chat_template(
 
     return tokenizer.apply_chat_template(
         messages=messages,
+        tools=tools,
         **kwargs,
     )
@@ -690,8 +690,10 @@ class LLM:
         model_config = self.llm_engine.get_model_config()
         resolved_content_format = resolve_chat_template_content_format(
             chat_template,
+            tools,
             chat_template_content_format,
             tokenizer,
+            trust_remote_code=model_config.trust_remote_code,
         )
 
         prompts: list[Union[TokensPrompt, TextPrompt]] = []
@@ -713,18 +715,19 @@ class LLM:
                     tokenizer,
                     messages=msgs,
                     chat_template=chat_template,
+                    tools=tools,
                     add_generation_prompt=add_generation_prompt,
                     continue_final_message=continue_final_message,
-                    tools=tools,
                 )
             else:
                 prompt_data = apply_hf_chat_template(
                     tokenizer,
+                    trust_remote_code=model_config.trust_remote_code,
                     conversation=conversation,
                     chat_template=chat_template,
+                    tools=tools,
                     add_generation_prompt=add_generation_prompt,
                     continue_final_message=continue_final_message,
-                    tools=tools,
                 )
 
             prompt: Union[TokensPrompt, TextPrompt]
@@ -379,14 +379,18 @@ class OpenAIServing:
         add_special_tokens: bool = False,
     ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
               list[TokensPrompt]]:
+        model_config = self.model_config
+
         resolved_content_format = resolve_chat_template_content_format(
             chat_template,
+            tool_dicts,
             chat_template_content_format,
             tokenizer,
+            trust_remote_code=model_config.trust_remote_code,
         )
         conversation, mm_data_future = parse_chat_messages_futures(
            messages,
-            self.model_config,
+            model_config,
             tokenizer,
             content_format=resolved_content_format,
         )
@@ -410,6 +414,7 @@ class OpenAIServing:
         else:
             request_prompt = apply_hf_chat_template(
                 tokenizer,
+                trust_remote_code=model_config.trust_remote_code,
                 conversation=conversation,
                 **_chat_template_kwargs,
             )
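Not part of the diff: a hedged end-to-end sketch of how the updated public entry points are called after this change, mirroring the caller updates in the hunks above. Both `resolve_chat_template_content_format` and `apply_hf_chat_template` now accept `tools` and `trust_remote_code`; the model name and message content below are placeholders.

from transformers import AutoTokenizer

from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
                                         resolve_chat_template_content_format)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
messages = [{"role": "user", "content": "Hello!"}]

# Format detection and prompt rendering now share the same template resolution.
content_format = resolve_chat_template_content_format(
    None,      # chat_template: detect from the tokenizer/processor
    None,      # tools
    "auto",    # given_format
    tokenizer,
    trust_remote_code=False,
)

prompt = apply_hf_chat_template(
    tokenizer,
    trust_remote_code=False,
    conversation=messages,
    chat_template=None,
    tools=None,
    add_generation_prompt=True,
)
print(content_format, prompt[:80])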