vllm/tests/entrypoints/openai/test_chat_template.py
Russell Bryant e489ad7a21
[Misc] Add SPDX-License-Identifier headers to python source files (#12628)
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files
    
    This commit adds SPDX license headers to python source files as
    recommended to the project by the Linux Foundation. These headers
    provide a concise way that is both human and machine readable for
    communicating license information for each source file. It helps
    avoid any ambiguity about the license of the code and can also be
    easily used by tools to help manage license compliance.

    The Linux Foundation runs license scans against the codebase to help
    ensure we are in compliance with the licenses of the code we use,
    including dependencies. Having these headers in place helps that
    tool do its job.
    
    More information can be found on the SPDX site:
    
    - https://spdx.dev/learn/handling-license-info/
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-02-02 11:58:18 -08:00


# SPDX-License-Identifier: Apache-2.0

import pytest

from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
                                         load_chat_template)
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer

from ...utils import VLLM_PATH
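
# VLLM_PATH (imported from tests/utils.py) points at the repository root, so
# the chatml template below is resolved from the in-tree examples/ directory.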
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
# Define models, templates, and their corresponding expected outputs
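# Each tuple is (model, chat template path, add_generation_prompt,
# continue_final_message, expected rendered prompt).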
MODEL_TEMPLATE_GENERATON_OUTPUT = [
("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
"""),
("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of"""),
("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
The capital of"""),
]
TEST_MESSAGES = [
    {
        'role': 'user',
        'content': 'Hello'
    },
    {
        'role': 'assistant',
        'content': 'Hi there!'
    },
    {
        'role': 'user',
        'content': 'What is the capital of'
    },
]
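
# Appended as the final turn when continue_final_message is exercised.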
ASSISTANT_MESSAGE_TO_CONTINUE = {
    'role': 'assistant',
    'content': 'The capital of'
}


def test_load_chat_template():
    # Testing chatml template
    template_content = load_chat_template(chat_template=chatml_jinja_path)

    # Test assertions
    assert template_content is not None
    # Hard coded value for template_chatml.jinja
    assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501


def test_no_load_chat_template_filelike():
    # A value that looks like a file path but does not exist is rejected
    template = "../../examples/does_not_exist"

    with pytest.raises(ValueError, match="looks like a file path"):
        load_chat_template(chat_template=template)


def test_no_load_chat_template_literallike():
    # A literal (inline) template string is returned unchanged
    template = "{{ messages }}"

    template_content = load_chat_template(chat_template=template)
    assert template_content == template


@pytest.mark.parametrize(
    "model,template,add_generation_prompt,continue_final_message,expected_output",
    MODEL_TEMPLATE_GENERATON_OUTPUT)
def test_get_gen_prompt(model, template, add_generation_prompt,
                        continue_final_message, expected_output):
    # Initialize the tokenizer
    tokenizer = get_tokenizer(tokenizer_name=model)
    template_content = load_chat_template(chat_template=template)

    # Create a mock request object using keyword arguments
    mock_request = ChatCompletionRequest(
        model=model,
        messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
        if continue_final_message else TEST_MESSAGES,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final_message,
    )

    # Call the function and get the result
    result = apply_hf_chat_template(
        tokenizer,
        conversation=mock_request.messages,
        chat_template=mock_request.chat_template or template_content,
        add_generation_prompt=mock_request.add_generation_prompt,
        continue_final_message=mock_request.continue_final_message,
    )

    # Test assertion
    assert result == expected_output, (
        f"The generated prompt does not match the expected output for "
        f"model {model} and template {template}")