Fix/async chat serving (#2727)

Sebastian Schoennenbeck 2024-05-03 20:04:14 +02:00 committed by GitHub
parent 7e65477e5e
commit f8e7adda21
5 changed files with 73 additions and 21 deletions

View File

@@ -60,12 +60,13 @@ class MockServingChat:
     tokenizer: MockTokenizer


-def test_load_chat_template():
+@pytest.mark.asyncio
+async def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=chatml_jinja_path)
+    await OpenAIServingChat._load_chat_template(
+        mock_serving_chat, chat_template=chatml_jinja_path)

     template_content = tokenizer.chat_template
@@ -76,7 +77,8 @@ def test_load_chat_template():
 {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501


-def test_no_load_chat_template_filelike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
     tokenizer = MockTokenizer()
@@ -84,18 +86,19 @@ def test_no_load_chat_template_filelike():
     mock_serving_chat = MockServingChat(tokenizer)

     with pytest.raises(ValueError, match="looks like a file path"):
-        OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                              chat_template=template)
+        await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                    chat_template=template)


-def test_no_load_chat_template_literallike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)

     template_content = tokenizer.chat_template
     assert template_content == template
@@ -110,8 +113,8 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)

     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
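These tests become coroutines because `_load_chat_template` is now async and must be awaited. A minimal sketch of the same conversion pattern, assuming `pytest-asyncio` is installed; `load_template` is a hypothetical stand-in for the real coroutine under test:

    import asyncio

    import pytest


    async def load_template(template: str) -> str:
        # Hypothetical stand-in for OpenAIServingChat._load_chat_template.
        await asyncio.sleep(0)  # yield to the event loop once
        return template


    @pytest.mark.asyncio
    async def test_load_template():
        # The marker makes pytest-asyncio run this coroutine on an event
        # loop, so the test body can use `await` directly.
        assert await load_template("{{ messages }}") == "{{ messages }}"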

View File

@@ -0,0 +1,37 @@
+import asyncio
+from dataclasses import dataclass
+
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+
+MODEL_NAME = "openai-community/gpt2"
+CHAT_TEMPLATE = "Dummy chat template for testing {}"
+
+
+@dataclass
+class MockModelConfig:
+    tokenizer = MODEL_NAME
+    trust_remote_code = False
+    tokenizer_mode = "auto"
+    max_model_len = 100
+    tokenizer_revision = None
+
+
+@dataclass
+class MockEngine:
+
+    async def get_model_config(self):
+        return MockModelConfig
+
+
+async def _async_serving_chat_init():
+    serving_completion = OpenAIServingChat(MockEngine(),
+                                           served_model_names=[MODEL_NAME],
+                                           response_role="assistant",
+                                           chat_template=CHAT_TEMPLATE)
+    return serving_completion
+
+
+def test_async_serving_chat_init():
+    serving_completion = asyncio.run(_async_serving_chat_init())
+    assert serving_completion.tokenizer is not None
+    assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
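Because `_async_serving_chat_init` is itself driven by `asyncio.run`, the constructor executes with an event loop already running and so exercises the `create_task` branch of the base class (the same path Ray Serve hits). A minimal sketch of that loop-detection idiom, with hypothetical names:

    import asyncio


    class Service:

        def __init__(self):
            self.ready = False
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None
            if loop is not None and loop.is_running():
                # Already inside a loop (e.g. Ray Serve): schedule the
                # async init as a background task instead of blocking.
                loop.create_task(self._post_init())
            else:
                # No loop yet: drive the async init to completion here.
                asyncio.run(self._post_init())

        async def _post_init(self):
            self.ready = True


    async def main():
        service = Service()     # constructed inside a running loop
        await asyncio.sleep(0)  # let the scheduled task run once
        assert service.ready


    asyncio.run(main())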

View File

@@ -150,7 +150,7 @@ def server(zephyr_lora_files):
     ray.shutdown()


-@pytest.fixture(scope="session")
+@pytest.fixture(scope="module")
 def client():
     client = openai.AsyncOpenAI(
         base_url="http://localhost:8000/v1",
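Narrowing the fixture from `scope="session"` to `scope="module"` means each test module builds a fresh `AsyncOpenAI` client rather than sharing one across the whole run, presumably so the async client is not reused across event loops. A sketch of the scoping difference, with a hypothetical fixture body:

    import pytest


    @pytest.fixture(scope="module")
    def client():
        # Built once per test module; with scope="session" a single
        # instance would be shared by every module in the run.
        return {"base_url": "http://localhost:8000/v1"}  # stand-in client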

View File

@@ -1,3 +1,4 @@
+import asyncio
 import codecs
 import time
 from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
@@ -40,9 +41,11 @@ class OpenAIServingChat(OpenAIServing):
                  chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         await_post_init=self._load_chat_template(
+                             chat_template=chat_template))
         self.response_role = response_role
-        self._load_chat_template(chat_template)

     def _parse_chat_message_content(
         self,
@@ -356,7 +359,10 @@ class OpenAIServingChat(OpenAIServing):
         return response

-    def _load_chat_template(self, chat_template: Optional[str]):
+    async def _load_chat_template(self, chat_template: Optional[str]):
+        while self.tokenizer is None:
+            # Give the parent class time to load the tokenizer
+            await asyncio.sleep(0.1)
         tokenizer = self.tokenizer

         if chat_template is not None:
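`_load_chat_template` now polls until the base class's `_post_init` has assigned `self.tokenizer`, yielding to the event loop between checks so both coroutines can make progress on the same loop. A minimal sketch of this wait-until-ready pattern, with hypothetical names:

    import asyncio


    class Worker:

        def __init__(self):
            self.tokenizer = None

        async def _post_init(self):
            # Simulates the parent class loading the tokenizer.
            await asyncio.sleep(0.3)
            self.tokenizer = "tokenizer"

        async def _load_chat_template(self, template: str) -> str:
            # Poll until the attribute is populated, sleeping between
            # checks so _post_init can run on the same event loop.
            while self.tokenizer is None:
                await asyncio.sleep(0.1)
            return f"{self.tokenizer}: {template}"


    async def main():
        worker = Worker()
        init = asyncio.create_task(worker._post_init())
        print(await worker._load_chat_template("{{ messages }}"))
        await init


    asyncio.run(main())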

View File

@@ -2,7 +2,7 @@ import asyncio
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 from pydantic import Field
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -29,8 +29,11 @@ class LoRAModulePath:
 class OpenAIServing:

-    def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]],
+                 await_post_init: Optional[Awaitable[Any]] = None):
         self.engine = engine
         self.served_model_names = served_model_names
         if lora_modules is None:
@@ -56,12 +59,12 @@ class OpenAIServing:
         if event_loop is not None and event_loop.is_running():
             # If the current is instanced by Ray Serve,
             # there is already a running event loop
-            event_loop.create_task(self._post_init())
+            event_loop.create_task(self._post_init(await_post_init))
         else:
             # When using single vLLM without engine_use_ray
-            asyncio.run(self._post_init())
+            asyncio.run(self._post_init(await_post_init))

-    async def _post_init(self):
+    async def _post_init(self, await_post_init):
         engine_model_config = await self.engine.get_model_config()
         self.max_model_len = engine_model_config.max_model_len
@@ -73,6 +76,9 @@ class OpenAIServing:
             trust_remote_code=engine_model_config.trust_remote_code,
             truncation_side="left")

+        if await_post_init is not None:
+            await await_post_init
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
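The base class now accepts an optional awaitable and awaits it at the end of `_post_init`, letting a subclass queue setup work that depends on state the base class creates asynchronously (here, the tokenizer). A minimal sketch of that hand-off, with hypothetical names, assuming construction outside a running loop:

    import asyncio
    from typing import Any, Awaitable, Optional


    class Base:

        def __init__(self, await_post_init: Optional[Awaitable[Any]] = None):
            self.tokenizer = None
            # The real class also handles the already-running-loop case
            # via event_loop.create_task; this sketch keeps the sync path.
            asyncio.run(self._post_init(await_post_init))

        async def _post_init(self, await_post_init):
            self.tokenizer = "tokenizer"  # base-class async setup
            if await_post_init is not None:
                # Subclass setup runs only after the base state exists.
                await await_post_init


    class Child(Base):

        def __init__(self):
            # Pass an un-awaited coroutine upward; Base awaits it last.
            super().__init__(await_post_init=self._load_template())

        async def _load_template(self):
            assert self.tokenizer is not None


    Child()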