Fix/async chat serving (#2727)

Sebastian Schoennenbeck 2024-05-03 20:04:14 +02:00 committed by GitHub
parent 7e65477e5e
commit f8e7adda21
5 changed files with 73 additions and 21 deletions

View File

@@ -60,12 +60,13 @@ class MockServingChat:
     tokenizer: MockTokenizer


-def test_load_chat_template():
+@pytest.mark.asyncio
+async def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=chatml_jinja_path)
+    await OpenAIServingChat._load_chat_template(
+        mock_serving_chat, chat_template=chatml_jinja_path)

     template_content = tokenizer.chat_template
@@ -76,7 +77,8 @@ def test_load_chat_template():
 {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501


-def test_no_load_chat_template_filelike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
     tokenizer = MockTokenizer()
@@ -84,18 +86,19 @@ def test_no_load_chat_template_filelike():
     mock_serving_chat = MockServingChat(tokenizer)

     with pytest.raises(ValueError, match="looks like a file path"):
-        OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                              chat_template=template)
+        await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                    chat_template=template)


-def test_no_load_chat_template_literallike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
     tokenizer = MockTokenizer()

     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)
     template_content = tokenizer.chat_template

     assert template_content == template
@@ -110,8 +113,8 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)

     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
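These template tests become coroutines, so they rely on an async-capable pytest plugin such as pytest-asyncio to provide the event loop. A minimal, self-contained sketch of the same pattern (the helper below is hypothetical and not part of vLLM):

import pytest


@pytest.mark.asyncio
async def test_async_template_helper():
    # Hypothetical awaitable standing in for OpenAIServingChat._load_chat_template.
    async def load_template():
        return "{{ messages }}"

    # The plugin supplies the running loop, so the test body can simply await.
    assert await load_template() == "{{ messages }}"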

View File

@@ -0,0 +1,37 @@
+import asyncio
+from dataclasses import dataclass
+
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+
+MODEL_NAME = "openai-community/gpt2"
+CHAT_TEMPLATE = "Dummy chat template for testing {}"
+
+
+@dataclass
+class MockModelConfig:
+    tokenizer = MODEL_NAME
+    trust_remote_code = False
+    tokenizer_mode = "auto"
+    max_model_len = 100
+    tokenizer_revision = None
+
+
+@dataclass
+class MockEngine:
+
+    async def get_model_config(self):
+        return MockModelConfig
+
+
+async def _async_serving_chat_init():
+    serving_completion = OpenAIServingChat(MockEngine(),
+                                           served_model_names=[MODEL_NAME],
+                                           response_role="assistant",
+                                           chat_template=CHAT_TEMPLATE)
+    return serving_completion
+
+
+def test_async_serving_chat_init():
+    serving_completion = asyncio.run(_async_serving_chat_init())
+    assert serving_completion.tokenizer is not None
+    assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
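The new test constructs the serving object inside asyncio.run, so an event loop is already running during construction, which is the code path this commit fixes. A rough sketch of why that wrapper matters (hypothetical class, not the vLLM code):

import asyncio


class LoopAwareInit:
    """Hypothetical class whose __init__ branches on the event-loop state."""

    def __init__(self):
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None
        # With a running loop, async post-init must be scheduled as a task;
        # without one, it can be driven to completion with asyncio.run().
        self.init_path = "task" if loop is not None else "run"


async def construct_in_loop():
    return LoopAwareInit()


print(LoopAwareInit().init_path)                   # -> "run"
print(asyncio.run(construct_in_loop()).init_path)  # -> "task"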

View File

@@ -150,7 +150,7 @@ def server(zephyr_lora_files):
     ray.shutdown()


-@pytest.fixture(scope="session")
+@pytest.fixture(scope="module")
 def client():
     client = openai.AsyncOpenAI(
         base_url="http://localhost:8000/v1",
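For reference, a small sketch of what the scope change means in pytest terms: a module-scoped fixture is rebuilt for each test module rather than shared across the whole session (the names below are made up):

import pytest


@pytest.fixture(scope="module")
def shared_client():
    # Built once per test module, then discarded; scope="session" would
    # reuse the same object across every module in the run.
    return {"base_url": "http://localhost:8000/v1"}


def test_client_exists(shared_client):
    assert shared_client["base_url"].endswith("/v1")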

View File

@@ -1,3 +1,4 @@
+import asyncio
 import codecs
 import time
 from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
@@ -40,9 +41,11 @@ class OpenAIServingChat(OpenAIServing):
                  chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         await_post_init=self._load_chat_template(
+                             chat_template=chat_template))
         self.response_role = response_role
-        self._load_chat_template(chat_template)

     def _parse_chat_message_content(
         self,
@@ -356,7 +359,10 @@ class OpenAIServingChat(OpenAIServing):
         return response

-    def _load_chat_template(self, chat_template: Optional[str]):
+    async def _load_chat_template(self, chat_template: Optional[str]):
+        while self.tokenizer is None:
+            # Give the parent class time to load the tokenizer
+            await asyncio.sleep(0.1)
         tokenizer = self.tokenizer
         if chat_template is not None:
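Because _load_chat_template is now awaited concurrently with the parent's _post_init, it spins on self.tokenizer until the tokenizer exists. A minimal sketch of that wait-for-attribute pattern (hypothetical names, not the vLLM implementation):

import asyncio
from types import SimpleNamespace


class TemplateLoader:
    """Hypothetical stand-in for the serving class."""

    def __init__(self):
        self.tokenizer = None

    async def _post_init(self):
        await asyncio.sleep(0.2)  # pretend the tokenizer takes time to load
        self.tokenizer = SimpleNamespace(chat_template=None)

    async def _load_chat_template(self, chat_template):
        while self.tokenizer is None:
            # Yield control until another coroutine has created the tokenizer.
            await asyncio.sleep(0.1)
        self.tokenizer.chat_template = chat_template


async def main():
    loader = TemplateLoader()
    await asyncio.gather(loader._post_init(),
                         loader._load_chat_template("{{ messages }}"))
    assert loader.tokenizer.chat_template == "{{ messages }}"


asyncio.run(main())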

View File

@@ -2,7 +2,7 @@ import asyncio
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 from pydantic import Field
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -29,8 +29,11 @@ class LoRAModulePath:

 class OpenAIServing:

-    def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]],
+                 await_post_init: Optional[Awaitable[Any]] = None):
         self.engine = engine
         self.served_model_names = served_model_names
         if lora_modules is None:
@@ -56,12 +59,12 @@ class OpenAIServing:
         if event_loop is not None and event_loop.is_running():
             # If the current is instanced by Ray Serve,
             # there is already a running event loop
-            event_loop.create_task(self._post_init())
+            event_loop.create_task(self._post_init(await_post_init))
         else:
             # When using single vLLM without engine_use_ray
-            asyncio.run(self._post_init())
+            asyncio.run(self._post_init(await_post_init))

-    async def _post_init(self):
+    async def _post_init(self, await_post_init):
         engine_model_config = await self.engine.get_model_config()
         self.max_model_len = engine_model_config.max_model_len
@@ -73,6 +76,9 @@ class OpenAIServing:
             trust_remote_code=engine_model_config.trust_remote_code,
             truncation_side="left")

+        if await_post_init is not None:
+            await await_post_init
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
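Putting the pieces together, the base class now accepts an optional awaitable and awaits it at the end of its async post-init, after the tokenizer has been created. A condensed sketch of that hook (hypothetical names, simplified from the diff above):

import asyncio
from typing import Any, Awaitable, Optional


class ServingBase:

    def __init__(self, await_post_init: Optional[Awaitable[Any]] = None):
        self.tokenizer = None
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None
        if loop is not None and loop.is_running():
            # Already inside an event loop (e.g. Ray Serve): run as a task.
            loop.create_task(self._post_init(await_post_init))
        else:
            # No loop yet: block until post-init has finished.
            asyncio.run(self._post_init(await_post_init))

    async def _post_init(self, await_post_init):
        self.tokenizer = object()  # stand-in for building the real tokenizer
        if await_post_init is not None:
            # Extra async setup handed down by a subclass, e.g. loading a
            # chat template that needs the tokenizer to exist first.
            await await_post_init


class ServingChild(ServingBase):

    def __init__(self, chat_template: str):
        # The coroutine is created here but only awaited inside _post_init.
        super().__init__(await_post_init=self._load_chat_template(chat_template))

    async def _load_chat_template(self, chat_template: str):
        while self.tokenizer is None:
            await asyncio.sleep(0.1)
        self.chat_template = chat_template


child = ServingChild("Dummy chat template for testing {}")
assert child.chat_template == "Dummy chat template for testing {}"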