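"""Tests for loading chat templates and rendering chat prompts through the
OpenAI-compatible chat completion entrypoints."""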
import os
import pathlib

import pytest

from vllm.entrypoints.openai.chat_utils import load_chat_template
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer

chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
    __file__))).parent.parent / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATION_OUTPUT = [
    ("facebook/opt-125m", None, True,
     "Hello</s>Hi there!</s>What is the capital of</s>"),
    ("facebook/opt-125m", None, False,
     "Hello</s>Hi there!</s>What is the capital of</s>"),
    ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
"""),
    ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of"""),
]

TEST_MESSAGES = [
    {
        'role': 'user',
        'content': 'Hello'
    },
    {
        'role': 'assistant',
        'content': 'Hi there!'
    },
    {
        'role': 'user',
        'content': 'What is the capital of'
    },
]


def test_load_chat_template():
    # Testing chatml template
    template_content = load_chat_template(chat_template=chatml_jinja_path)

    # Test assertions
    assert template_content is not None
    # Hard coded value for template_chatml.jinja
    assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501


def test_no_load_chat_template_filelike():
    # Testing chatml template
    template = "../../examples/does_not_exist"

    with pytest.raises(ValueError, match="looks like a file path"):
        load_chat_template(chat_template=template)
2024-05-09 13:48:33 +08:00
def test_no_load_chat_template_literallike ( ) :
2024-04-24 02:19:03 +08:00
# Testing chatml template
template = " {{ messages }} "
2024-07-18 00:13:30 -07:00
template_content = load_chat_template ( chat_template = template )
2023-11-30 19:43:13 -05:00
2024-04-24 02:19:03 +08:00
assert template_content == template


@pytest.mark.parametrize(
    "model,template,add_generation_prompt,expected_output",
    MODEL_TEMPLATE_GENERATION_OUTPUT)
def test_get_gen_prompt(model, template, add_generation_prompt,
                        expected_output):
    # Initialize the tokenizer
    tokenizer = get_tokenizer(tokenizer_name=model)
    template_content = load_chat_template(chat_template=template)

    # Create a mock request object using keyword arguments
    mock_request = ChatCompletionRequest(
        model=model,
        messages=TEST_MESSAGES,
        add_generation_prompt=add_generation_prompt)

    # Call the function and get the result
    result = tokenizer.apply_chat_template(
        conversation=mock_request.messages,
        tokenize=False,
        add_generation_prompt=mock_request.add_generation_prompt,
        chat_template=mock_request.chat_template or template_content)

    # Test assertion
    assert result == expected_output, (
        f"The generated prompt does not match the expected output for "
        f"model {model} and template {template}")