Fix/async chat serving (#2727)

parent 7e65477e5e
commit f8e7adda21
@@ -60,12 +60,13 @@ class MockServingChat:
     tokenizer: MockTokenizer


-def test_load_chat_template():
+@pytest.mark.asyncio
+async def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=chatml_jinja_path)
+    await OpenAIServingChat._load_chat_template(
+        mock_serving_chat, chat_template=chatml_jinja_path)

     template_content = tokenizer.chat_template

@@ -76,7 +77,8 @@ def test_load_chat_template():
 {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501


-def test_no_load_chat_template_filelike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
     tokenizer = MockTokenizer()
@@ -84,18 +86,19 @@ def test_no_load_chat_template_filelike():
     mock_serving_chat = MockServingChat(tokenizer)

     with pytest.raises(ValueError, match="looks like a file path"):
-        OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                              chat_template=template)
+        await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                    chat_template=template)


-def test_no_load_chat_template_literallike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
     tokenizer = MockTokenizer()

     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)
     template_content = tokenizer.chat_template

     assert template_content == template
@@ -110,8 +113,8 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)

     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
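Aside, for context: these tests become coroutines because `_load_chat_template` is now awaited. A minimal sketch of the pytest-asyncio pattern they rely on, assuming pytest and pytest-asyncio are installed; the loader here is an illustrative stand-in, not vLLM code:

import asyncio

import pytest


async def load_template(store: dict, template: str) -> None:
    # Stand-in async loader: eventually writes the template onto a
    # shared object, loosely mirroring what _load_chat_template does.
    await asyncio.sleep(0)
    store["chat_template"] = template


@pytest.mark.asyncio
async def test_load_template():
    store: dict = {}
    await load_template(store, "{{ messages }}")
    assert store["chat_template"] == "{{ messages }}"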
tests/entrypoints/openai/test_serving_chat.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+import asyncio
+from dataclasses import dataclass
+
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+
+MODEL_NAME = "openai-community/gpt2"
+CHAT_TEMPLATE = "Dummy chat template for testing {}"
+
+
+@dataclass
+class MockModelConfig:
+    tokenizer = MODEL_NAME
+    trust_remote_code = False
+    tokenizer_mode = "auto"
+    max_model_len = 100
+    tokenizer_revision = None
+
+
+@dataclass
+class MockEngine:
+
+    async def get_model_config(self):
+        return MockModelConfig
+
+
+async def _async_serving_chat_init():
+    serving_completion = OpenAIServingChat(MockEngine(),
+                                           served_model_names=[MODEL_NAME],
+                                           response_role="assistant",
+                                           chat_template=CHAT_TEMPLATE)
+    return serving_completion
+
+
+def test_async_serving_chat_init():
+    serving_completion = asyncio.run(_async_serving_chat_init())
+    assert serving_completion.tokenizer is not None
+    assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
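The new test wraps construction in a coroutine and drives it with asyncio.run, so the constructor executes with a running event loop and the coroutine's return value comes back to synchronous test code. A minimal sketch of that shape, with illustrative names only:

import asyncio


async def build() -> dict:
    # Construction may schedule or await async setup work in here.
    await asyncio.sleep(0)
    return {"tokenizer": "ready"}


def test_build():
    obj = asyncio.run(build())  # returns the coroutine's result
    assert obj["tokenizer"] == "ready"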
@@ -150,7 +150,7 @@ def server(zephyr_lora_files):
     ray.shutdown()


-@pytest.fixture(scope="session")
+@pytest.fixture(scope="module")
 def client():
     client = openai.AsyncOpenAI(
         base_url="http://localhost:8000/v1",
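The scope change means the AsyncOpenAI client fixture is rebuilt once per test module instead of being shared across the whole pytest session. A small, hedged illustration of module scope (not the vLLM fixture itself):

import pytest


@pytest.fixture(scope="module")
def shared_state():
    # Built once for all tests in this module, then torn down.
    return {"calls": 0}


def test_first(shared_state):
    shared_state["calls"] += 1
    assert shared_state["calls"] == 1


def test_second(shared_state):
    shared_state["calls"] += 1
    assert shared_state["calls"] == 2  # same object reused within the module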
@@ -1,3 +1,4 @@
+import asyncio
 import codecs
 import time
 from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
@@ -40,9 +41,11 @@ class OpenAIServingChat(OpenAIServing):
                  chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         await_post_init=self._load_chat_template(
+                             chat_template=chat_template))
+
         self.response_role = response_role
-        self._load_chat_template(chat_template)

     def _parse_chat_message_content(
         self,
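Worth spelling out why the subclass can build the coroutine inside __init__ without awaiting it there: calling an async method only creates a coroutine object, which the parent class can await later. A small sketch of that idea under illustrative names (not the vLLM classes):

import asyncio


class Child:

    def __init__(self) -> None:
        # No await here: __init__ cannot be async, so we only create the
        # coroutine and hand it off to be awaited elsewhere.
        self.pending = self._load()

    async def _load(self) -> str:
        return "loaded"


async def main() -> None:
    child = Child()
    assert await child.pending == "loaded"


asyncio.run(main())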
@@ -356,7 +359,10 @@ class OpenAIServingChat(OpenAIServing):

         return response

-    def _load_chat_template(self, chat_template: Optional[str]):
+    async def _load_chat_template(self, chat_template: Optional[str]):
+        while self.tokenizer is None:
+            # Give the parent class time to load the tokenizer
+            await asyncio.sleep(0.1)
         tokenizer = self.tokenizer

         if chat_template is not None:
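The new wait loop is a simple poll-until-ready pattern: keep yielding with short asyncio sleeps until another task has set the attribute. A self-contained sketch of the same pattern with illustrative names, keeping the 0.1 s interval from the diff:

import asyncio


class Holder:
    tokenizer = None


async def setter(holder: Holder) -> None:
    await asyncio.sleep(0.3)
    holder.tokenizer = object()


async def waiter(holder: Holder) -> str:
    while holder.tokenizer is None:
        await asyncio.sleep(0.1)  # yield so the setter task can make progress
    return "ready"


async def main() -> None:
    holder = Holder()
    asyncio.create_task(setter(holder))
    print(await waiter(holder))  # prints "ready" once the setter has run


asyncio.run(main())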
@@ -2,7 +2,7 @@ import asyncio
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 from pydantic import Field
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -29,8 +29,11 @@ class LoRAModulePath:

 class OpenAIServing:

-    def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]],
+                 await_post_init: Optional[Awaitable[Any]] = None):
         self.engine = engine
         self.served_model_names = served_model_names
         if lora_modules is None:
@@ -56,12 +59,12 @@ class OpenAIServing:
         if event_loop is not None and event_loop.is_running():
             # If the current is instanced by Ray Serve,
             # there is already a running event loop
-            event_loop.create_task(self._post_init())
+            event_loop.create_task(self._post_init(await_post_init))
         else:
             # When using single vLLM without engine_use_ray
-            asyncio.run(self._post_init())
+            asyncio.run(self._post_init(await_post_init))

-    async def _post_init(self):
+    async def _post_init(self, await_post_init):
         engine_model_config = await self.engine.get_model_config()
         self.max_model_len = engine_model_config.max_model_len

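The branch above is the usual "am I already inside an event loop?" dispatch: schedule the async setup as a task when a loop is running (the Ray Serve case), otherwise block on asyncio.run. A hedged sketch of the same dispatch, written with get_running_loop rather than the get_event_loop call used here:

import asyncio
from typing import Any, Coroutine


def schedule_setup(setup_coro: Coroutine[Any, Any, None]) -> None:
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    if loop is not None:
        # Already inside a running loop: fire-and-forget the setup task.
        loop.create_task(setup_coro)
    else:
        # No loop running: block until setup completes.
        asyncio.run(setup_coro)


async def setup() -> None:
    await asyncio.sleep(0)


schedule_setup(setup())  # no loop is running here, so asyncio.run is used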
@@ -73,6 +76,9 @@ class OpenAIServing:
             trust_remote_code=engine_model_config.trust_remote_code,
             truncation_side="left")

+        if await_post_init is not None:
+            await await_post_init
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
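Taken together with the constructor change above, the effect is an optional post-init hook: the base class finishes its own async setup (loading the tokenizer) and only then awaits whatever awaitable the subclass handed in. A compact sketch of that contract, with illustrative names rather than the actual vLLM signatures:

import asyncio
from typing import Any, Awaitable, Optional


async def post_init(state: dict,
                    await_post_init: Optional[Awaitable[Any]] = None) -> None:
    state["tokenizer"] = "fake-tokenizer"  # stand-in for base-class setup
    if await_post_init is not None:
        await await_post_init  # subclass-supplied follow-up work


async def load_template(state: dict) -> None:
    # Runs only after the tokenizer stand-in is in place.
    state["chat_template"] = "template for {}".format(state["tokenizer"])


state: dict = {}
asyncio.run(post_init(state, await_post_init=load_template(state)))
assert state["chat_template"] == "template for fake-tokenizer"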