import asyncio
from contextlib import suppress
from dataclasses import dataclass
from unittest.mock import MagicMock

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"


@dataclass
class MockModelConfig:
    # Minimal stand-in exposing only the model-config fields that
    # OpenAIServingChat reads in these tests.
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    embedding_mode = False


@dataclass
class MockEngine:

    async def get_model_config(self):
        return MockModelConfig()


async def _async_serving_chat_init():
    engine = MockEngine()
    model_config = await engine.get_model_config()

    serving_completion = OpenAIServingChat(engine,
                                           model_config,
                                           served_model_names=[MODEL_NAME],
                                           response_role="assistant",
                                           chat_template=CHAT_TEMPLATE,
                                           lora_modules=None,
                                           prompt_adapters=None,
                                           request_logger=None)
    return serving_completion


def test_async_serving_chat_init():
    serving_completion = asyncio.run(_async_serving_chat_init())
    assert serving_completion.chat_template == CHAT_TEMPLATE


def test_serving_chat_should_set_correct_max_tokens():
    mock_engine = MagicMock(spec=AsyncLLMEngine)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)

    serving_chat = OpenAIServingChat(mock_engine,
                                     MockModelConfig(),
                                     served_model_names=[MODEL_NAME],
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    # The mocked engine cannot actually generate, so the call is expected to
    # fail downstream; this test only inspects the SamplingParams that were
    # passed to engine.generate().
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    # With no max_tokens in the request, it should default to the remaining
    # context: max_model_len (100) minus the tokenized prompt length.
    assert mock_engine.generate.call_args.args[1].max_tokens == 93

    # An explicit max_tokens in the request should be passed through as-is.
    req.max_tokens = 10
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 10