[Frontend] Move async logic outside of constructor (#4674)

Authored by Cyrus Leung on 2024-05-09 13:48:33 +08:00; committed by GitHub
parent 16bc0a098f
commit f12b20decc
7 changed files with 96 additions and 102 deletions
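
The change replaces constructors that scheduled their own async setup (awaiting engine.get_model_config(), building the tokenizer, loading the chat template) with a pattern where the caller awaits the model config first and then constructs the serving objects synchronously. A minimal sketch of that pattern, using stand-in names (StubModelConfig, DummyEngine, Server) rather than vLLM's real classes:

    import asyncio
    from dataclasses import dataclass


    @dataclass
    class StubModelConfig:      # stand-in for vllm.config.ModelConfig
        max_model_len: int


    class DummyEngine:          # stand-in for AsyncLLMEngine
        async def get_model_config(self) -> StubModelConfig:
            return StubModelConfig(max_model_len=4096)


    class Server:               # stand-in for the OpenAIServing classes
        # After the refactor, __init__ only receives already-resolved values,
        # so no event-loop juggling happens inside the constructor.
        def __init__(self, engine: DummyEngine, model_config: StubModelConfig):
            self.engine = engine
            self.max_model_len = model_config.max_model_len


    async def main() -> None:
        engine = DummyEngine()
        model_config = await engine.get_model_config()  # async work done up front
        server = Server(engine, model_config)           # plain synchronous construction
        print(server.max_model_len)


    asyncio.run(main())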


@@ -60,13 +60,12 @@ class MockServingChat:
     tokenizer: MockTokenizer
 
 
-@pytest.mark.asyncio
-async def test_load_chat_template():
+def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    await OpenAIServingChat._load_chat_template(
-        mock_serving_chat, chat_template=chatml_jinja_path)
+    OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                          chat_template=chatml_jinja_path)
 
     template_content = tokenizer.chat_template
@@ -77,8 +76,7 @@ async def test_load_chat_template():
 {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501
 
 
-@pytest.mark.asyncio
-async def test_no_load_chat_template_filelike():
+def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
     tokenizer = MockTokenizer()
@@ -86,35 +84,33 @@ async def test_no_load_chat_template_filelike():
     mock_serving_chat = MockServingChat(tokenizer)
 
     with pytest.raises(ValueError, match="looks like a file path"):
-        await OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                                    chat_template=template)
+        OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                              chat_template=template)
 
 
-@pytest.mark.asyncio
-async def test_no_load_chat_template_literallike():
+def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
     tokenizer = MockTokenizer()
 
     mock_serving_chat = MockServingChat(tokenizer)
-    await OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                                chat_template=template)
+    OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                          chat_template=template)
     template_content = tokenizer.chat_template
 
     assert template_content == template
 
 
-@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model,template,add_generation_prompt,expected_output",
     MODEL_TEMPLATE_GENERATON_OUTPUT)
-async def test_get_gen_prompt(model, template, add_generation_prompt,
-                              expected_output):
+def test_get_gen_prompt(model, template, add_generation_prompt,
+                        expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    await OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                                chat_template=template)
+    OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                          chat_template=template)
 
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
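
With _load_chat_template now an ordinary synchronous method, the @pytest.mark.asyncio markers and awaits above are no longer needed. A self-contained sketch of the same test shape, using simplified stand-ins for the mocks and the template loader (this is not the actual vLLM test code):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class MockTokenizer:            # stand-in: only a chat_template attribute is needed
        chat_template: Optional[str] = None


    @dataclass
    class MockServingChat:          # stand-in serving object holding a tokenizer
        tokenizer: MockTokenizer


    def load_chat_template(serving: MockServingChat, chat_template: str) -> None:
        # simplified stand-in for OpenAIServingChat._load_chat_template
        serving.tokenizer.chat_template = chat_template


    def test_load_chat_template():  # plain pytest test: no asyncio marker, no await
        tokenizer = MockTokenizer()
        mock_serving_chat = MockServingChat(tokenizer)
        load_chat_template(mock_serving_chat, chat_template="{{ messages }}")
        assert tokenizer.chat_template == "{{ messages }}"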


@@ -20,11 +20,15 @@ class MockModelConfig:
 class MockEngine:
 
     async def get_model_config(self):
-        return MockModelConfig
+        return MockModelConfig()
 
 
-serving_completion = OpenAIServingChat(MockEngine(),
-                                       served_model_names=[MODEL_NAME],
-                                       response_role="assistant",
-                                       chat_template=CHAT_TEMPLATE)
+async def _async_serving_chat_init():
+    engine = MockEngine()
+    model_config = await engine.get_model_config()
+
+    serving_completion = OpenAIServingChat(engine,
+                                           model_config,
+                                           served_model_names=[MODEL_NAME],
+                                           response_role="assistant",
+                                           chat_template=CHAT_TEMPLATE)
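
The test now builds the serving object inside a small async helper so that get_model_config() can be awaited before the synchronous constructor runs. A rough sketch of that shape, with illustrative mocks in place of the real fixtures:

    import asyncio


    class MockModelConfig:          # stand-in config with a single attribute
        max_model_len = 2048


    class MockEngine:               # stand-in async engine
        async def get_model_config(self) -> MockModelConfig:
            return MockModelConfig()


    async def _async_init():
        # Mirrors the helper above: await the config, then construct synchronously.
        engine = MockEngine()
        model_config = await engine.get_model_config()
        return engine, model_config


    # In the real test this helper would be awaited from an `async def test_...`
    # (pytest-asyncio); a plain script can drive it with asyncio.run().
    engine, model_config = asyncio.run(_async_init())
    assert model_config.max_model_len == 2048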


@@ -516,7 +516,7 @@ class EngineArgs:
         return parser
 
     @classmethod
-    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
+    def from_cli_args(cls, args: argparse.Namespace):
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
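
The explicit -> 'EngineArgs' return annotation is dropped here, presumably because the classmethod is inherited by subclasses such as AsyncEngineArgs, where cls(**...) produces the subclass rather than the base class. A sketch of that factory pattern with made-up field names (not vLLM's real argument set):

    import argparse
    import dataclasses
    from dataclasses import dataclass


    @dataclass
    class EngineArgsSketch:                 # illustrative stand-in for EngineArgs
        model: str = "facebook/opt-125m"

        @classmethod
        def from_cli_args(cls, args: argparse.Namespace):
            # Collect this dataclass's attributes (including subclass fields) and
            # fill them from the parsed arguments; cls(...) returns whichever
            # class the method was called on.
            attrs = [attr.name for attr in dataclasses.fields(cls)]
            return cls(**{attr: getattr(args, attr) for attr in attrs})


    @dataclass
    class AsyncEngineArgsSketch(EngineArgsSketch):  # stand-in for AsyncEngineArgs
        engine_use_ray: bool = False


    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="facebook/opt-125m")
    parser.add_argument("--engine-use-ray", dest="engine_use_ray", action="store_true")
    args = parser.parse_args([])

    engine_args = AsyncEngineArgsSketch.from_cli_args(args)
    assert isinstance(engine_args, AsyncEngineArgsSketch)  # subclass, not the base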


@@ -4,7 +4,7 @@ import inspect
 import re
 from contextlib import asynccontextmanager
 from http import HTTPStatus
-from typing import Set
+from typing import Optional, Set
 
 import fastapi
 import uvicorn
@@ -164,15 +164,32 @@ if __name__ == "__main__":
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(
         engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
+
+    event_loop: Optional[asyncio.AbstractEventLoop]
+    try:
+        event_loop = asyncio.get_running_loop()
+    except RuntimeError:
+        event_loop = None
+
+    if event_loop is not None and event_loop.is_running():
+        # If the current is instanced by Ray Serve,
+        # there is already a running event loop
+        model_config = event_loop.run_until_complete(engine.get_model_config())
+    else:
+        # When using single vLLM without engine_use_ray
+        model_config = asyncio.run(engine.get_model_config())
+
+    openai_serving_chat = OpenAIServingChat(engine, model_config,
+                                            served_model_names,
                                             args.response_role,
                                             args.lora_modules,
                                             args.chat_template)
     openai_serving_completion = OpenAIServingCompletion(
-        engine, served_model_names, args.lora_modules)
+        engine, model_config, served_model_names, args.lora_modules)
     app.root_path = args.root_path
 
     uvicorn.run(app,
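
Because the model config must now be fetched before any serving object exists, the startup code probes for an already-running event loop and drives the coroutine accordingly. A standalone sketch of that detection logic (fetch_config is a stand-in for engine.get_model_config()):

    import asyncio
    from typing import Optional


    async def fetch_config() -> dict:
        # stand-in for engine.get_model_config()
        return {"max_model_len": 4096}


    def resolve_config_blocking() -> dict:
        # asyncio.get_running_loop() raises RuntimeError when no loop is running
        # in the current thread, which is how the two startup modes are told apart.
        event_loop: Optional[asyncio.AbstractEventLoop]
        try:
            event_loop = asyncio.get_running_loop()
        except RuntimeError:
            event_loop = None

        if event_loop is not None and event_loop.is_running():
            # A loop already exists (e.g. the process was launched by an async
            # host such as Ray Serve). run_until_complete() drives the coroutine
            # on that loop; it can only be called from a thread that is not
            # itself executing the loop.
            return event_loop.run_until_complete(fetch_config())

        # Plain `python api_server.py` startup: asyncio.run() creates a fresh
        # loop, runs the coroutine to completion, and closes the loop again.
        return asyncio.run(fetch_config())


    print(resolve_config_blocking())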


@@ -1,4 +1,3 @@
-import asyncio
 import codecs
 import time
 from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
@@ -8,6 +7,7 @@ from fastapi import Request
 from openai.types.chat import (ChatCompletionContentPartParam,
                                ChatCompletionRole)
 
+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest, ChatCompletionResponse,
@@ -35,17 +35,47 @@ class OpenAIServingChat(OpenAIServing):
 
     def __init__(self,
                  engine: AsyncLLMEngine,
+                 model_config: ModelConfig,
                  served_model_names: List[str],
                  response_role: str,
                  lora_modules: Optional[List[LoRAModulePath]] = None,
                  chat_template: Optional[str] = None):
         super().__init__(engine=engine,
+                         model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules,
-                         await_post_init=self._load_chat_template(
-                             chat_template=chat_template))
+                         lora_modules=lora_modules)
 
         self.response_role = response_role
+        self._load_chat_template(chat_template)
+
+    def _load_chat_template(self, chat_template: Optional[str]):
+        tokenizer = self.tokenizer
+
+        if chat_template is not None:
+            try:
+                with open(chat_template, "r") as f:
+                    tokenizer.chat_template = f.read()
+            except OSError as e:
+                JINJA_CHARS = "{}\n"
+                if not any(c in chat_template for c in JINJA_CHARS):
+                    msg = (f"The supplied chat template ({chat_template}) "
+                           f"looks like a file path, but it failed to be "
+                           f"opened. Reason: {e}")
+                    raise ValueError(msg) from e
+
+                # If opening a file fails, set chat template to be args to
+                # ensure we decode so our escape are interpreted correctly
+                tokenizer.chat_template = codecs.decode(
+                    chat_template, "unicode_escape")
+
+            logger.info("Using supplied chat template:\n%s",
+                        tokenizer.chat_template)
+        elif tokenizer.chat_template is not None:
+            logger.info("Using default chat template:\n%s",
+                        tokenizer.chat_template)
+        else:
+            logger.warning(
+                "No chat template provided. Chat API will not work.")
 
     def _parse_chat_message_content(
         self,
@@ -357,36 +387,4 @@ class OpenAIServingChat(OpenAIServing):
             usage=usage,
         )
 
-        return response
-
-    async def _load_chat_template(self, chat_template: Optional[str]):
-        while self.tokenizer is None:
-            # Give the parent class time to load the tokenizer
-            await asyncio.sleep(0.1)
-        tokenizer = self.tokenizer
-
-        if chat_template is not None:
-            try:
-                with open(chat_template, "r") as f:
-                    tokenizer.chat_template = f.read()
-            except OSError as e:
-                JINJA_CHARS = "{}\n"
-                if not any(c in chat_template for c in JINJA_CHARS):
-                    msg = (f"The supplied chat template ({chat_template}) "
-                           f"looks like a file path, but it failed to be "
-                           f"opened. Reason: {e}")
-                    raise ValueError(msg) from e
-
-                # If opening a file fails, set chat template to be args to
-                # ensure we decode so our escape are interpreted correctly
-                tokenizer.chat_template = codecs.decode(
-                    chat_template, "unicode_escape")
-
-            logger.info("Using supplied chat template:\n%s",
-                        tokenizer.chat_template)
-        elif tokenizer.chat_template is not None:
-            logger.info("Using default chat template:\n%s",
-                        tokenizer.chat_template)
-        else:
-            logger.warning(
-                "No chat template provided. Chat API will not work.")
+        return response
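
_load_chat_template accepts either a path to a template file or the template text itself, and treats a string without any Jinja-ish characters that also fails to open as a probable mistyped path. A standalone sketch of that file-or-literal logic (a plain function here, not vLLM's method):

    import codecs


    def load_chat_template(chat_template: str) -> str:
        """Sketch of the file-or-literal handling shown above."""
        try:
            with open(chat_template, "r") as f:
                return f.read()
        except OSError as e:
            # Heuristic from the diff: no Jinja-ish characters plus a failed open
            # usually means the caller meant a file path and got it wrong.
            JINJA_CHARS = "{}\n"
            if not any(c in chat_template for c in JINJA_CHARS):
                raise ValueError(
                    f"The supplied chat template ({chat_template}) looks like a "
                    f"file path, but it failed to be opened. Reason: {e}") from e
            # Otherwise treat the argument itself as the template; unicode_escape
            # turns literal "\n" sequences from the command line into real newlines.
            return codecs.decode(chat_template, "unicode_escape")


    print(load_chat_template("{{ messages }}"))         # literal template
    print(load_chat_template("Hello\\nWorld {{ x }}"))  # escaped newline is decoded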


@@ -4,6 +4,7 @@ from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
 
 from fastapi import Request
 
+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               CompletionResponse,
@@ -52,11 +53,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
 
 class OpenAIServingCompletion(OpenAIServing):
 
-    def __init__(self,
-                 engine: AsyncLLMEngine,
+    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]] = None):
+                 lora_modules: Optional[List[LoRAModulePath]]):
         super().__init__(engine=engine,
+                         model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)


@@ -1,13 +1,12 @@
-import asyncio
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 from pydantic import Field
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from typing_extensions import Annotated
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest, ErrorResponse,
@@ -29,13 +28,24 @@ class LoRAModulePath:
 
 class OpenAIServing:
 
-    def __init__(self,
-                 engine: AsyncLLMEngine,
+    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]],
-                 await_post_init: Optional[Awaitable[Any]] = None):
+                 lora_modules: Optional[List[LoRAModulePath]]):
         super().__init__()
+
         self.engine = engine
+        self.max_model_len = model_config.max_model_len
+
+        # A separate tokenizer to map token IDs to strings.
+        self.tokenizer = get_tokenizer(
+            model_config.tokenizer,
+            tokenizer_mode=model_config.tokenizer_mode,
+            tokenizer_revision=model_config.tokenizer_revision,
+            trust_remote_code=model_config.trust_remote_code,
+            truncation_side="left")
+
         self.served_model_names = served_model_names
+
         if lora_modules is None:
             self.lora_requests = []
         else:
@@ -47,38 +57,6 @@ class OpenAIServing:
                 ) for i, lora in enumerate(lora_modules, start=1)
             ]
 
-        self.max_model_len = 0
-        # Lazy initialized
-        self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-
-        try:
-            event_loop = asyncio.get_running_loop()
-        except RuntimeError:
-            event_loop = None
-
-        if event_loop is not None and event_loop.is_running():
-            # If the current is instanced by Ray Serve,
-            # there is already a running event loop
-            event_loop.create_task(self._post_init(await_post_init))
-        else:
-            # When using single vLLM without engine_use_ray
-            asyncio.run(self._post_init(await_post_init))
-
-    async def _post_init(self, await_post_init):
-        engine_model_config = await self.engine.get_model_config()
-        self.max_model_len = engine_model_config.max_model_len
-
-        # A separate tokenizer to map token IDs to strings.
-        self.tokenizer = get_tokenizer(
-            engine_model_config.tokenizer,
-            tokenizer_mode=engine_model_config.tokenizer_mode,
-            tokenizer_revision=engine_model_config.tokenizer_revision,
-            trust_remote_code=engine_model_config.trust_remote_code,
-            truncation_side="left")
-
-        if await_post_init is not None:
-            await await_post_init
-
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
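
For contrast, the deleted block had __init__ schedule _post_init() as a task on whatever loop was available, which is why the old code (and the old tests) had to poll until self.tokenizer appeared. A small demonstration of why constructor-scheduled async setup forces that polling (illustration only, not vLLM code):

    import asyncio


    class LazyInit:
        """Sketch of the removed pattern: __init__ schedules async setup as a task."""

        def __init__(self) -> None:
            self.tokenizer = None
            # The task is created but has not run by the time __init__ returns.
            asyncio.get_running_loop().create_task(self._post_init())

        async def _post_init(self) -> None:
            await asyncio.sleep(0.05)      # pretend to load something
            self.tokenizer = "ready"


    async def main() -> None:
        obj = LazyInit()
        print(obj.tokenizer)               # None: setup has not finished yet
        while obj.tokenizer is None:       # the polling the old code needed
            await asyncio.sleep(0.01)
        print(obj.tokenizer)               # "ready"


    asyncio.run(main())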