OpenAI Server refactoring (#2360)
parent e1957c6ebd · commit 14cc317ba4
@@ -19,6 +19,9 @@ steps:
|
||||
- label: Engine Test
|
||||
command: pytest -v -s engine
|
||||
|
||||
- label: Entrypoints Test
|
||||
command: pytest -v -s entrypoints
|
||||
|
||||
- label: Kernels Test
|
||||
command: pytest -v -s kernels
|
||||
soft_fail: true
|
||||
|
@@ -16,3 +16,6 @@ pytest-asyncio
|
||||
httpx
|
||||
einops # required for MPT
|
||||
flash_attn # required for HuggingFace's llama implementation
|
||||
openai
|
||||
requests
|
||||
ray
|
@@ -1,12 +1,12 @@
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from vllm.entrypoints.openai.api_server import *
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
|
||||
chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
|
||||
__file__))).parent.parent / "examples/template_chatml.jinja"
|
||||
@@ -48,7 +48,6 @@ TEST_MESSAGES = [
|
||||
'content': 'What is the capital of'
|
||||
},
|
||||
]
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -56,13 +55,17 @@ class MockTokenizer:
|
||||
chat_template = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockServingChat:
|
||||
tokenizer: MockTokenizer
|
||||
|
||||
|
||||
def test_load_chat_template():
|
||||
# Testing chatml template
|
||||
mock_args = Namespace(chat_template=chatml_jinja_path)
|
||||
tokenizer = MockTokenizer()
|
||||
|
||||
# Call the function with the mocked args
|
||||
load_chat_template(mock_args, tokenizer)
|
||||
mock_serving_chat = MockServingChat(tokenizer)
|
||||
OpenAIServingChat._load_chat_template(mock_serving_chat,
|
||||
chat_template=chatml_jinja_path)
|
||||
|
||||
template_content = tokenizer.chat_template
|
||||
|
||||
@@ -76,11 +79,11 @@ def test_load_chat_template():
|
||||
def test_no_load_chat_template():
|
||||
# Testing chatml template
|
||||
template = "../../examples/does_not_exist"
|
||||
mock_args = Namespace(chat_template=template)
|
||||
tokenizer = MockTokenizer()
|
||||
|
||||
# Call the function with the mocked args
|
||||
load_chat_template(mock_args, tokenizer=tokenizer)
|
||||
mock_serving_chat = MockServingChat(tokenizer)
|
||||
OpenAIServingChat._load_chat_template(mock_serving_chat,
|
||||
chat_template=template)
|
||||
template_content = tokenizer.chat_template
|
||||
|
||||
# Test assertions
|
||||
@@ -97,9 +100,9 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
expected_output):
|
||||
# Initialize the tokenizer
|
||||
tokenizer = get_tokenizer(tokenizer_name=model)
|
||||
|
||||
mock_args = Namespace(chat_template=template)
|
||||
load_chat_template(mock_args, tokenizer)
|
||||
mock_serving_chat = MockServingChat(tokenizer)
|
||||
OpenAIServingChat._load_chat_template(mock_serving_chat,
|
||||
chat_template=template)
|
||||
|
||||
# Create a mock request object using keyword arguments
|
||||
mock_request = ChatCompletionRequest(
|
||||
@@ -115,8 +118,3 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
|
||||
# Test assertion
|
||||
assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
|
||||
|
||||
|
||||
def test_health_endpoint():
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
193
tests/entrypoints/test_openai_server.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import time
|
||||
import subprocess
|
||||
|
||||
import sys
|
||||
import pytest
|
||||
import requests
|
||||
import ray # using Ray for overall ease of process management, parallel requests, and debugging.
|
||||
import openai # use the official client for correctness check
|
||||
|
||||
MAX_SERVER_START_WAIT_S = 600 # wait up to 600 seconds for the server to start
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class ServerRunner:
|
||||
|
||||
def __init__(self, args):
|
||||
self.proc = subprocess.Popen(
|
||||
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
self._wait_for_server()
|
||||
|
||||
def ready(self):
|
||||
return True
|
||||
|
||||
def _wait_for_server(self):
|
||||
# run health check
|
||||
start = time.time()
|
||||
while True:
|
||||
try:
|
||||
if requests.get(
|
||||
"http://localhost:8000/health").status_code == 200:
|
||||
break
|
||||
except Exception as err:
|
||||
if self.proc.poll() is not None:
|
||||
raise RuntimeError("Server exited unexpectedly.") from err
|
||||
|
||||
time.sleep(0.5)
|
||||
if time.time() - start > MAX_SERVER_START_WAIT_S:
|
||||
raise RuntimeError(
|
||||
"Server failed to start in time.") from err
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "proc"):
|
||||
self.proc.terminate()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server():
|
||||
ray.init()
|
||||
server_runner = ServerRunner.remote([
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
"--dtype",
|
||||
"bfloat16", # use half precision for speed and memory savings in CI environment
|
||||
"--max-model-len",
|
||||
"8192"
|
||||
])
|
||||
ray.get(server_runner.ready.remote())
|
||||
yield server_runner
|
||||
ray.shutdown()
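# (Illustrative note, not part of this commit) The ServerRunner actor above is
# roughly equivalent to launching the server by hand on a GPU host:
#   python3 -m vllm.entrypoints.openai.api_server \
#       --model HuggingFaceH4/zephyr-7b-beta --dtype bfloat16 --max-model-len 8192
# The fixture then polls http://localhost:8000/health until it returns HTTP 200.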
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def client():
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key="token-abc123",
|
||||
)
|
||||
yield client
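# A minimal connectivity check is possible before the real tests run; a sketch
# (assuming the AsyncPage shape returned by the official client's models.list()):
#
# async def test_served_model_is_listed(server, client: openai.AsyncOpenAI):
#     models = await client.models.list()
#     assert models.data[0].id == MODEL_NAME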
|
||||
|
||||
|
||||
async def test_single_completion(server, client: openai.AsyncOpenAI):
|
||||
completion = await client.completions.create(model=MODEL_NAME,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
assert completion.choices[0].text is not None and len(
|
||||
completion.choices[0].text) >= 5
|
||||
assert completion.choices[0].finish_reason == "length"
|
||||
assert completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=5, prompt_tokens=6, total_tokens=11)
|
||||
|
||||
|
||||
async def test_single_chat_session(server, client: openai.AsyncOpenAI):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
)
|
||||
assert chat_completion.id is not None
|
||||
assert chat_completion.choices is not None and len(
|
||||
chat_completion.choices) == 1
|
||||
assert chat_completion.choices[0].message is not None
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None and len(message.content) >= 10
|
||||
assert message.role == "assistant"
|
||||
messages.append({"role": "assistant", "content": message.content})
|
||||
|
||||
# test multi-turn dialogue
|
||||
messages.append({"role": "user", "content": "express your result in json"})
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
)
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None and len(message.content) >= 0
|
||||
|
||||
|
||||
async def test_completion_streaming(server, client: openai.AsyncOpenAI):
|
||||
prompt = "What is an LLM?"
|
||||
|
||||
single_completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
single_output = single_completion.choices[0].text
|
||||
single_usage = single_completion.usage
|
||||
|
||||
stream = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
)
|
||||
chunks = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk.choices[0].text)
|
||||
assert chunk.choices[0].finish_reason == "length"
|
||||
assert chunk.usage == single_usage
|
||||
assert "".join(chunks) == single_output
|
||||
|
||||
|
||||
async def test_chat_streaming(server, client: openai.AsyncOpenAI):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
)
|
||||
output = chat_completion.choices[0].message.content
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
|
||||
# test streaming
|
||||
stream = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
)
|
||||
chunks = []
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.role:
|
||||
assert delta.role == "assistant"
|
||||
if delta.content:
|
||||
chunks.append(delta.content)
|
||||
assert chunk.choices[0].finish_reason == stop_reason
|
||||
assert "".join(chunks) == output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
@@ -1,19 +1,12 @@
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import codecs
|
||||
import json
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from aioprometheus import MetricsMiddleware
|
||||
from aioprometheus.asgi.starlette import metrics
|
||||
import fastapi
|
||||
import uvicorn
|
||||
from http import HTTPStatus
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -22,26 +15,16 @@ from fastapi.responses import JSONResponse, StreamingResponse, Response
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.metrics import add_global_metrics_labels
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
CompletionRequest, CompletionResponse, CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice, CompletionStreamResponse,
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
|
||||
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
|
||||
openai_serving_chat: OpenAIServingChat = None
|
||||
openai_serving_completion: OpenAIServingCompletion = None
|
||||
logger = init_logger(__name__)
|
||||
served_model = None
|
||||
engine_args = None
|
||||
engine = None
|
||||
response_role = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -120,72 +103,10 @@ app.add_middleware(MetricsMiddleware) # Trace HTTP server metrics
|
||||
app.add_route("/metrics", metrics) # Exposes HTTP metrics
|
||||
|
||||
|
||||
def create_error_response(status_code: HTTPStatus,
|
||||
message: str) -> JSONResponse:
|
||||
return JSONResponse(ErrorResponse(message=message,
|
||||
type="invalid_request_error").dict(),
|
||||
status_code=status_code.value)
|
||||
|
||||
|
||||
def load_chat_template(args, tokenizer):
|
||||
if args.chat_template is not None:
|
||||
try:
|
||||
with open(args.chat_template, "r") as f:
|
||||
chat_template = f.read()
|
||||
except OSError:
|
||||
# If opening a file fails, set chat template to be args to
|
||||
# ensure we decode so our escape are interpreted correctly
|
||||
chat_template = codecs.decode(args.chat_template, "unicode_escape")
|
||||
|
||||
tokenizer.chat_template = chat_template
|
||||
logger.info(
|
||||
f"Using supplied chat template:\n{tokenizer.chat_template}")
|
||||
elif tokenizer.chat_template is not None:
|
||||
logger.info(f"Using default chat template:\n{tokenizer.chat_template}")
|
||||
else:
|
||||
logger.warning("No chat template provided. Chat API will not work.")
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(_, exc):
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
|
||||
|
||||
|
||||
async def check_model(request) -> Optional[JSONResponse]:
|
||||
if request.model == served_model:
|
||||
return
|
||||
ret = create_error_response(
|
||||
HTTPStatus.NOT_FOUND,
|
||||
f"The model `{request.model}` does not exist.",
|
||||
)
|
||||
return ret
|
||||
|
||||
|
||||
async def check_length(
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
prompt: Optional[str] = None,
|
||||
prompt_ids: Optional[List[int]] = None
|
||||
) -> Tuple[List[int], Optional[JSONResponse]]:
|
||||
assert (not (prompt is None and prompt_ids is None)
|
||||
and not (prompt is not None and prompt_ids is not None)
|
||||
), "Either prompt or prompt_ids should be provided."
|
||||
input_ids = prompt_ids if prompt_ids is not None else tokenizer(
|
||||
prompt).input_ids
|
||||
token_num = len(input_ids)
|
||||
|
||||
if request.max_tokens is None:
|
||||
request.max_tokens = max_model_len - token_num
|
||||
if token_num + request.max_tokens > max_model_len:
|
||||
return input_ids, create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
f"This model's maximum context length is {max_model_len} tokens. "
|
||||
f"However, you requested {request.max_tokens + token_num} tokens "
|
||||
f"({token_num} in the messages, "
|
||||
f"{request.max_tokens} in the completion). "
|
||||
f"Please reduce the length of the messages or completion.",
|
||||
)
|
||||
else:
|
||||
return input_ids, None
|
||||
err = openai_serving_chat.create_error_response(message=str(exc))
|
||||
return JSONResponse(err.dict(), status_code=HTTPStatus.BAD_REQUEST)
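# For illustration (field names taken from create_error_response in
# serving_engine.py; the full ErrorResponse schema is not shown in this diff),
# a validation failure now returns HTTP 400 with a body shaped roughly like:
#   {"message": "<validation detail>", "type": "BadRequestError", "code": 400}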
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
@@ -196,544 +117,31 @@ async def health() -> Response:
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def show_available_models():
|
||||
"""Show available models. Right now we only have one model."""
|
||||
model_cards = [
|
||||
ModelCard(id=served_model,
|
||||
root=served_model,
|
||||
permission=[ModelPermission()])
|
||||
]
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
|
||||
def create_logprobs(
|
||||
token_ids: List[int],
|
||||
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
|
||||
num_output_top_logprobs: Optional[int] = None,
|
||||
initial_text_offset: int = 0,
|
||||
) -> LogProbs:
|
||||
"""Create OpenAI-style logprobs."""
|
||||
logprobs = LogProbs()
|
||||
last_token_len = 0
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs = []
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is not None:
|
||||
token_logprob = step_top_logprobs[token_id]
|
||||
else:
|
||||
token_logprob = None
|
||||
token = tokenizer.convert_ids_to_tokens(token_id)
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(token_logprob)
|
||||
if len(logprobs.text_offset) == 0:
|
||||
logprobs.text_offset.append(initial_text_offset)
|
||||
else:
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] +
|
||||
last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs.append({
|
||||
tokenizer.convert_ids_to_tokens(i): p
|
||||
for i, p in step_top_logprobs.items()
|
||||
} if step_top_logprobs else None)
|
||||
return logprobs
|
||||
models = await openai_serving_chat.show_available_models()
|
||||
return JSONResponse(content=models.dict())
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def create_chat_completion(request: ChatCompletionRequest,
|
||||
raw_request: Request):
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/chat/create
|
||||
for the API specification. This API mimics the OpenAI ChatCompletion API.
|
||||
|
||||
NOTE: Currently we do not support the following features:
|
||||
- function_call (Users should implement this by themselves)
|
||||
- logit_bias (to be supported by vLLM engine)
|
||||
"""
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
# TODO: support logit_bias in vLLM engine.
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"logit_bias is not currently supported")
|
||||
|
||||
try:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
conversation=request.messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=request.add_generation_prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in applying chat template from request: {str(e)}")
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
|
||||
token_ids, error_check_ret = await check_length(request, prompt=prompt)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
created_time = int(time.monotonic())
|
||||
chunk_object_type = "chat.completion.chunk"
|
||||
try:
|
||||
spaces_between_special_tokens = request.spaces_between_special_tokens
|
||||
sampling_params = SamplingParams(
|
||||
n=request.n,
|
||||
presence_penalty=request.presence_penalty,
|
||||
frequency_penalty=request.frequency_penalty,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
min_p=request.min_p,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
max_tokens=request.max_tokens,
|
||||
best_of=request.best_of,
|
||||
top_k=request.top_k,
|
||||
ignore_eos=request.ignore_eos,
|
||||
use_beam_search=request.use_beam_search,
|
||||
skip_special_tokens=request.skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
|
||||
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||
token_ids)
|
||||
|
||||
def get_role() -> str:
|
||||
if request.add_generation_prompt:
|
||||
return response_role
|
||||
else:
|
||||
return request.messages[-1]["role"]
|
||||
|
||||
async def completion_stream_generator() -> AsyncGenerator[str, None]:
|
||||
# Send first response for each request.n (index) with the role
|
||||
role = get_role()
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, delta=DeltaMessage(role=role), finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response to echo the input portion of the last message
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
if request.messages and isinstance(
|
||||
request.messages, list) and request.messages[-1].get(
|
||||
"content") and request.messages[-1].get(
|
||||
"role") == role:
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
if last_msg_content:
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=last_msg_content),
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response for each token for each request.n (index)
|
||||
previous_texts = [""] * request.n
|
||||
previous_num_tokens = [0] * request.n
|
||||
finish_reason_sent = [False] * request.n
|
||||
async for res in result_generator:
|
||||
res: RequestOutput
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
|
||||
if finish_reason_sent[i]:
|
||||
continue
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Send token-by-token response for each request.n
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
previous_texts[i] = output.text
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=delta_text),
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
# Send the finish response for each request.n only once
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=previous_num_tokens[i],
|
||||
total_tokens=prompt_tokens + previous_num_tokens[i],
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, delta=[], finish_reason=output.finish_reason)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
if final_usage is not None:
|
||||
chunk.usage = final_usage
|
||||
data = chunk.json(exclude_unset=True,
|
||||
exclude_none=True,
|
||||
ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
finish_reason_sent[i] = True
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def completion_full_generator():
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await engine.abort(request_id)
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"Client disconnected")
|
||||
final_res = res
|
||||
assert final_res is not None
|
||||
|
||||
choices = []
|
||||
role = get_role()
|
||||
for output in final_res.outputs:
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role=role, content=output.text),
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
if request.messages and isinstance(
|
||||
request.messages, list) and request.messages[-1].get(
|
||||
"content") and request.messages[-1].get(
|
||||
"role") == role:
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
|
||||
for choice in choices:
|
||||
full_message = last_msg_content + choice.message.content
|
||||
choice.message.content = full_message
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
# Streaming response
|
||||
if request.stream:
|
||||
return StreamingResponse(completion_stream_generator(),
|
||||
generator = await openai_serving_chat.create_chat_completion(
|
||||
request, raw_request)
|
||||
if request.stream and not isinstance(generator, ErrorResponse):
|
||||
return StreamingResponse(content=generator,
|
||||
media_type="text/event-stream")
|
||||
else:
|
||||
return await completion_full_generator()
|
||||
return JSONResponse(content=generator.dict())
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following features:
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
- logit_bias (to be supported by vLLM engine)
|
||||
"""
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# OpenAI API supports echoing the prompt when max_tokens is 0.
|
||||
echo_without_generation = request.echo and request.max_tokens == 0
|
||||
|
||||
if request.suffix is not None:
|
||||
# The language models we currently support do not support suffix.
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"suffix is not currently supported")
|
||||
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
# TODO: support logit_bias in vLLM engine.
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"logit_bias is not currently supported")
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
|
||||
use_token_ids = False
|
||||
if isinstance(request.prompt, list):
|
||||
if len(request.prompt) == 0:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"please provide at least one prompt")
|
||||
first_element = request.prompt[0]
|
||||
if isinstance(first_element, int):
|
||||
use_token_ids = True
|
||||
prompt = request.prompt
|
||||
elif isinstance(first_element, (str, list)):
|
||||
# TODO: handles multiple prompt case in list[list[int]]
|
||||
if len(request.prompt) > 1:
|
||||
return create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"multiple prompts in a batch is not currently supported")
|
||||
use_token_ids = not isinstance(first_element, str)
|
||||
prompt = request.prompt[0]
|
||||
else:
|
||||
prompt = request.prompt
|
||||
|
||||
if use_token_ids:
|
||||
_, error_check_ret = await check_length(request, prompt_ids=prompt)
|
||||
else:
|
||||
token_ids, error_check_ret = await check_length(request, prompt=prompt)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
created_time = int(time.monotonic())
|
||||
try:
|
||||
spaces_between_special_tokens = request.spaces_between_special_tokens
|
||||
sampling_params = SamplingParams(
|
||||
n=request.n,
|
||||
best_of=request.best_of,
|
||||
presence_penalty=request.presence_penalty,
|
||||
frequency_penalty=request.frequency_penalty,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
min_p=request.min_p,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
ignore_eos=request.ignore_eos,
|
||||
max_tokens=request.max_tokens
|
||||
if not echo_without_generation else 1,
|
||||
logprobs=request.logprobs,
|
||||
use_beam_search=request.use_beam_search,
|
||||
prompt_logprobs=request.logprobs if request.echo else None,
|
||||
skip_special_tokens=request.skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
|
||||
if use_token_ids:
|
||||
result_generator = engine.generate(None,
|
||||
sampling_params,
|
||||
request_id,
|
||||
prompt_token_ids=prompt)
|
||||
else:
|
||||
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||
token_ids)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. In addition, we do not stream the results when use beam search.
|
||||
stream = (request.stream
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
|
||||
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
logprobs: Optional[LogProbs] = None,
|
||||
finish_reason: Optional[str] = None,
|
||||
usage: Optional[UsageInfo] = None,
|
||||
) -> str:
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=index,
|
||||
text=text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
response = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[choice_data],
|
||||
)
|
||||
if usage is not None:
|
||||
response.usage = usage
|
||||
response_json = response.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
return response_json
|
||||
|
||||
async def completion_stream_generator() -> AsyncGenerator[str, None]:
|
||||
previous_texts = [""] * request.n
|
||||
previous_num_tokens = [0] * request.n
|
||||
has_echoed = [False] * request.n
|
||||
async for res in result_generator:
|
||||
res: RequestOutput
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
token_ids = output.token_ids[previous_num_tokens[i]:]
|
||||
if request.logprobs is not None:
|
||||
top_logprobs = output.logprobs[previous_num_tokens[i]:]
|
||||
else:
|
||||
top_logprobs = None
|
||||
offsets = len(previous_texts[i])
|
||||
if request.echo and not has_echoed[i]:
|
||||
if not echo_without_generation:
|
||||
delta_text = res.prompt + delta_text
|
||||
token_ids = res.prompt_token_ids + token_ids
|
||||
if top_logprobs:
|
||||
top_logprobs = res.prompt_logprobs + top_logprobs
|
||||
else: # only just return the prompt
|
||||
delta_text = res.prompt
|
||||
token_ids = res.prompt_token_ids
|
||||
if top_logprobs:
|
||||
top_logprobs = res.prompt_logprobs
|
||||
has_echoed[i] = True
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
initial_text_offset=offsets,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
previous_texts[i] = output.text
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
finish_reason = output.finish_reason
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
if output.finish_reason is not None:
|
||||
logprobs = (LogProbs()
|
||||
if request.logprobs is not None else None)
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
completion_tokens = len(output.token_ids)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text="",
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
usage=final_usage,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
return StreamingResponse(completion_stream_generator(),
|
||||
generator = await openai_serving_completion.create_completion(
|
||||
request, raw_request)
|
||||
if request.stream and not isinstance(generator, ErrorResponse):
|
||||
return StreamingResponse(content=generator,
|
||||
media_type="text/event-stream")
|
||||
|
||||
# Non-streaming response
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await engine.abort(request_id)
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"Client disconnected")
|
||||
final_res = res
|
||||
assert final_res is not None
|
||||
choices = []
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
prompt_logprobs = final_res.prompt_logprobs
|
||||
prompt_text = final_res.prompt
|
||||
for output in final_res.outputs:
|
||||
if request.logprobs is not None:
|
||||
if not echo_without_generation:
|
||||
token_ids = output.token_ids
|
||||
top_logprobs = output.logprobs
|
||||
if request.echo:
|
||||
token_ids = prompt_token_ids + token_ids
|
||||
top_logprobs = prompt_logprobs + top_logprobs
|
||||
else:
|
||||
token_ids = prompt_token_ids
|
||||
top_logprobs = prompt_logprobs
|
||||
logprobs = create_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
if not echo_without_generation:
|
||||
output_text = output.text
|
||||
if request.echo:
|
||||
output_text = prompt_text + output_text
|
||||
else:
|
||||
output_text = prompt_text
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=output.index,
|
||||
text=output_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = CompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
response_json = response.json(ensure_ascii=False)
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(fake_stream_generator(),
|
||||
media_type="text/event-stream")
|
||||
|
||||
return response
|
||||
else:
|
||||
return JSONResponse(content=generator.dict())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -754,19 +162,12 @@ if __name__ == "__main__":
|
||||
else:
|
||||
served_model = args.model
|
||||
|
||||
response_role = args.response_role
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine_model_config = asyncio.run(engine.get_model_config())
|
||||
max_model_len = engine_model_config.max_model_len
|
||||
|
||||
# A separate tokenizer to map token IDs to strings.
|
||||
tokenizer = get_tokenizer(
|
||||
engine_model_config.tokenizer,
|
||||
tokenizer_mode=engine_model_config.tokenizer_mode,
|
||||
trust_remote_code=engine_model_config.trust_remote_code)
|
||||
load_chat_template(args, tokenizer)
|
||||
openai_serving_chat = OpenAIServingChat(engine, served_model,
|
||||
args.response_role,
|
||||
args.chat_template)
|
||||
openai_serving_completion = OpenAIServingCompletion(engine, served_model)
|
||||
|
||||
# Register labels for metrics
|
||||
add_global_metrics_labels(model_name=engine_args.model)
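With the protocol handling moved into OpenAIServingChat and OpenAIServingCompletion, the same objects can be wired into a custom app. A minimal sketch of such wiring, assuming only the constructors shown in this commit (names are placeholders; streaming and ErrorResponse branching elided, see the endpoints above):

import fastapi

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion

MODEL = "HuggingFaceH4/zephyr-7b-beta"  # placeholder model name

app = fastapi.FastAPI()
engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=MODEL))
chat = OpenAIServingChat(engine, served_model=MODEL,
                         response_role="assistant", chat_template=None)
completion = OpenAIServingCompletion(engine, served_model=MODEL)

@app.post("/v1/chat/completions")
async def create_chat(request: ChatCompletionRequest,
                      raw_request: fastapi.Request):
    # Delegate to the serving layer; a production app would branch on
    # request.stream and ErrorResponse exactly as api_server.py does above.
    return await chat.create_chat_completion(request, raw_request)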
|
||||
|
288
vllm/entrypoints/openai/serving_chat.py
Normal file
@@ -0,0 +1,288 @@
|
||||
import time
|
||||
import codecs
|
||||
from fastapi import Request
|
||||
from typing import AsyncGenerator, AsyncIterator, Union
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||
UsageInfo)
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
def __init__(self,
|
||||
engine: AsyncLLMEngine,
|
||||
served_model: str,
|
||||
response_role: str,
|
||||
chat_template=None):
|
||||
super().__init__(engine=engine, served_model=served_model)
|
||||
self.response_role = response_role
|
||||
self._load_chat_template(chat_template)
|
||||
|
||||
async def create_chat_completion(
|
||||
self, request: ChatCompletionRequest, raw_request: Request
|
||||
) -> Union[ErrorResponse, AsyncGenerator[str, None],
|
||||
ChatCompletionResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/chat/create
|
||||
for the API specification. This API mimics the OpenAI ChatCompletion API.
|
||||
|
||||
NOTE: Currently we do not support the following features:
|
||||
- function_call (Users should implement this by themselves)
|
||||
- logit_bias (to be supported by vLLM engine)
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
# TODO: support logit_bias in vLLM engine.
|
||||
return self.create_error_response(
|
||||
"logit_bias is not currently supported")
|
||||
|
||||
try:
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
conversation=request.messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=request.add_generation_prompt)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in applying chat template from request: {str(e)}")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
token_ids, error_check_ret = await self._check_length(request,
|
||||
prompt=prompt)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
try:
|
||||
spaces_between_special_tokens = request.spaces_between_special_tokens
|
||||
sampling_params = SamplingParams(
|
||||
n=request.n,
|
||||
presence_penalty=request.presence_penalty,
|
||||
frequency_penalty=request.frequency_penalty,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
min_p=request.min_p,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
max_tokens=request.max_tokens,
|
||||
best_of=request.best_of,
|
||||
top_k=request.top_k,
|
||||
ignore_eos=request.ignore_eos,
|
||||
use_beam_search=request.use_beam_search,
|
||||
skip_special_tokens=request.skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = self.engine.generate(prompt, sampling_params,
|
||||
request_id, token_ids)
|
||||
# Streaming response
|
||||
if request.stream:
|
||||
return self.chat_completion_stream_generator(
|
||||
request, result_generator, request_id)
|
||||
else:
|
||||
return await self.chat_completion_full_generator(
|
||||
request, raw_request, result_generator, request_id)
|
||||
|
||||
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
|
||||
if request.add_generation_prompt:
|
||||
return self.response_role
|
||||
else:
|
||||
return request.messages[-1].role
|
||||
|
||||
async def chat_completion_stream_generator(
|
||||
self, request: ChatCompletionRequest,
|
||||
result_generator: AsyncIterator[RequestOutput], request_id: str
|
||||
) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
|
||||
|
||||
model_name = request.model
|
||||
created_time = int(time.monotonic())
|
||||
chunk_object_type = "chat.completion.chunk"
|
||||
|
||||
# Send first response for each request.n (index) with the role
|
||||
role = self.get_chat_request_role(request)
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, delta=DeltaMessage(role=role), finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response to echo the input portion of the last message
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
if request.messages and isinstance(
|
||||
request.messages, list) and request.messages[-1].get(
|
||||
"content") and request.messages[-1].get(
|
||||
"role") == role:
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
if last_msg_content:
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=last_msg_content),
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response for each token for each request.n (index)
|
||||
previous_texts = [""] * request.n
|
||||
previous_num_tokens = [0] * request.n
|
||||
finish_reason_sent = [False] * request.n
|
||||
async for res in result_generator:
|
||||
res: RequestOutput
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
|
||||
if finish_reason_sent[i]:
|
||||
continue
|
||||
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
previous_texts[i] = output.text
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Send token-by-token response for each request.n
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=delta_text),
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
# Send the finish response for each request.n only once
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=previous_num_tokens[i],
|
||||
total_tokens=prompt_tokens + previous_num_tokens[i],
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=delta_text),
|
||||
finish_reason=output.finish_reason)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
if final_usage is not None:
|
||||
chunk.usage = final_usage
|
||||
data = chunk.json(exclude_unset=True,
|
||||
exclude_none=True,
|
||||
ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
finish_reason_sent[i] = True
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
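# For illustration, each frame yielded above is an SSE line roughly like
# (values hypothetical, fields abbreviated):
#   data: {"id": "cmpl-...", "object": "chat.completion.chunk", "created": 1703031,
#          "model": "...", "choices": [{"index": 0, "delta": {"content": "Hi"},
#          "finish_reason": null}]}
# terminated by a final "data: [DONE]" frame, which is what the streaming tests parse.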
|
||||
|
||||
async def chat_completion_full_generator(
|
||||
self, request: ChatCompletionRequest, raw_request: Request,
|
||||
result_generator: AsyncIterator[RequestOutput],
|
||||
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
|
||||
|
||||
model_name = request.model
|
||||
created_time = int(time.monotonic())
|
||||
final_res: RequestOutput = None
|
||||
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await self.engine.abort(request_id)
|
||||
return self.create_error_response("Client disconnected")
|
||||
final_res = res
|
||||
assert final_res is not None
|
||||
|
||||
choices = []
|
||||
role = self.get_chat_request_role(request)
|
||||
for output in final_res.outputs:
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role=role, content=output.text),
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
if request.messages and isinstance(
|
||||
request.messages, list) and request.messages[-1].get(
|
||||
"content") and request.messages[-1].get(
|
||||
"role") == role:
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
|
||||
for choice in choices:
|
||||
full_message = last_msg_content + choice.message.content
|
||||
choice.message.content = full_message
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _load_chat_template(self, chat_template):
|
||||
if chat_template is not None:
|
||||
try:
|
||||
with open(chat_template, "r") as f:
|
||||
self.tokenizer.chat_template = f.read()
|
||||
except OSError:
|
||||
# If opening the file fails, treat chat_template as the template string itself
# and decode it so escape sequences (e.g. \n) are interpreted correctly
|
||||
self.tokenizer.chat_template = codecs.decode(
|
||||
chat_template, "unicode_escape")
|
||||
|
||||
logger.info(
|
||||
f"Using supplied chat template:\n{self.tokenizer.chat_template}"
|
||||
)
|
||||
elif self.tokenizer.chat_template is not None:
|
||||
logger.info(
|
||||
f"Using default chat template:\n{self.tokenizer.chat_template}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"No chat template provided. Chat API will not work.")
|
295
vllm/entrypoints/openai/serving_completion.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import time
|
||||
from fastapi import Request
|
||||
from typing import AsyncGenerator, Optional
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from .protocol import (CompletionRequest, CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse, LogProbs, UsageInfo)
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(self, engine: AsyncLLMEngine, served_model: str):
|
||||
super().__init__(engine=engine, served_model=served_model)
|
||||
|
||||
async def create_completion(self, request: CompletionRequest,
|
||||
raw_request: Request):
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following features:
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
- logit_bias (to be supported by vLLM engine)
|
||||
"""
|
||||
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# OpenAI API supports echoing the prompt when max_tokens is 0.
|
||||
echo_without_generation = request.echo and request.max_tokens == 0
|
||||
|
||||
if request.suffix is not None:
|
||||
# The language models we currently support do not support suffix.
|
||||
return self.create_error_response(
|
||||
"suffix is not currently supported")
|
||||
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
# TODO: support logit_bias in vLLM engine.
|
||||
return self.create_error_response(
|
||||
"logit_bias is not currently supported")
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
|
||||
use_token_ids = False
|
||||
if isinstance(request.prompt, list):
|
||||
if len(request.prompt) == 0:
|
||||
return self.create_error_response(
|
||||
"please provide at least one prompt")
|
||||
first_element = request.prompt[0]
|
||||
if isinstance(first_element, int):
|
||||
use_token_ids = True
|
||||
prompt = request.prompt
|
||||
elif isinstance(first_element, (str, list)):
|
||||
# TODO: handles multiple prompt case in list[list[int]]
|
||||
if len(request.prompt) > 1:
|
||||
return self.create_error_response(
|
||||
"multiple prompts in a batch is not currently supported"
|
||||
)
|
||||
use_token_ids = not isinstance(first_element, str)
|
||||
prompt = request.prompt[0]
|
||||
else:
|
||||
prompt = request.prompt
|
||||
|
||||
if use_token_ids:
|
||||
_, error_check_ret = await self._check_length(request,
|
||||
prompt_ids=prompt)
|
||||
else:
|
||||
token_ids, error_check_ret = await self._check_length(
|
||||
request, prompt=prompt)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
created_time = int(time.monotonic())
|
||||
try:
|
||||
spaces_between_special_tokens = request.spaces_between_special_tokens
|
||||
sampling_params = SamplingParams(
|
||||
n=request.n,
|
||||
best_of=request.best_of,
|
||||
presence_penalty=request.presence_penalty,
|
||||
frequency_penalty=request.frequency_penalty,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
min_p=request.min_p,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
ignore_eos=request.ignore_eos,
|
||||
max_tokens=request.max_tokens
|
||||
if not echo_without_generation else 1,
|
||||
logprobs=request.logprobs,
|
||||
use_beam_search=request.use_beam_search,
|
||||
prompt_logprobs=request.logprobs if request.echo else None,
|
||||
skip_special_tokens=request.skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if use_token_ids:
|
||||
result_generator = self.engine.generate(None,
|
||||
sampling_params,
|
||||
request_id,
|
||||
prompt_token_ids=prompt)
|
||||
else:
|
||||
result_generator = self.engine.generate(prompt, sampling_params,
|
||||
request_id, token_ids)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. In addition, we do not stream the results when using beam search.
|
||||
stream = (request.stream
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
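# Worked example of the rule above (values hypothetical): stream=True with n=2 and
# best_of=4 falls back to the non-streaming path, because all best_of candidates
# must finish before the best n can be selected; stream=True with n=2 and
# best_of=2 (or best_of=None) streams token by token as usual.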
|
||||
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
logprobs: Optional[LogProbs] = None,
|
||||
finish_reason: Optional[str] = None,
|
||||
usage: Optional[UsageInfo] = None,
|
||||
) -> str:
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=index,
|
||||
text=text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
response = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[choice_data],
|
||||
)
|
||||
if usage is not None:
|
||||
response.usage = usage
|
||||
response_json = response.json(exclude_unset=True,
|
||||
ensure_ascii=False)
|
||||
|
||||
return response_json
|
||||
|
||||
async def completion_stream_generator() -> AsyncGenerator[str, None]:
|
||||
previous_texts = [""] * request.n
|
||||
previous_num_tokens = [0] * request.n
|
||||
has_echoed = [False] * request.n
|
||||
async for res in result_generator:
|
||||
res: RequestOutput
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
token_ids = output.token_ids[previous_num_tokens[i]:]
|
||||
if request.logprobs is not None:
|
||||
top_logprobs = output.logprobs[previous_num_tokens[i]:]
|
||||
else:
|
||||
top_logprobs = None
|
||||
offsets = len(previous_texts[i])
|
||||
if request.echo and not has_echoed[i]:
|
||||
if not echo_without_generation:
|
||||
delta_text = res.prompt + delta_text
|
||||
token_ids = res.prompt_token_ids + token_ids
|
||||
if top_logprobs:
|
||||
top_logprobs = res.prompt_logprobs + top_logprobs
|
||||
else: # echo only: return just the prompt, no generated tokens
|
||||
delta_text = res.prompt
|
||||
token_ids = res.prompt_token_ids
|
||||
if top_logprobs:
|
||||
top_logprobs = res.prompt_logprobs
|
||||
has_echoed[i] = True
|
||||
if request.logprobs is not None:
|
||||
logprobs = self._create_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
initial_text_offset=offsets,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
previous_texts[i] = output.text
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
finish_reason = output.finish_reason
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
if output.finish_reason is not None:
|
||||
logprobs = (LogProbs()
|
||||
if request.logprobs is not None else None)
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
completion_tokens = len(output.token_ids)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text="",
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
usage=final_usage,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
return completion_stream_generator()
|
||||
|
||||
# Non-streaming response
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await self.engine.abort(request_id)
|
||||
return self.create_error_response("Client disconnected")
|
||||
final_res = res
|
||||
assert final_res is not None
|
||||
choices = []
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
prompt_logprobs = final_res.prompt_logprobs
|
||||
prompt_text = final_res.prompt
|
||||
for output in final_res.outputs:
|
||||
if request.logprobs is not None:
|
||||
if not echo_without_generation:
|
||||
token_ids = output.token_ids
|
||||
top_logprobs = output.logprobs
|
||||
if request.echo:
|
||||
token_ids = prompt_token_ids + token_ids
|
||||
top_logprobs = prompt_logprobs + top_logprobs
|
||||
else:
|
||||
token_ids = prompt_token_ids
|
||||
top_logprobs = prompt_logprobs
|
||||
logprobs = self._create_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
if not echo_without_generation:
|
||||
output_text = output.text
|
||||
if request.echo:
|
||||
output_text = prompt_text + output_text
|
||||
else:
|
||||
output_text = prompt_text
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=output.index,
|
||||
text=output_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = CompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
response_json = response.json(ensure_ascii=False)
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return fake_stream_generator()
|
||||
|
||||
return response
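The echo path above mirrors the OpenAI behaviour where max_tokens=0 returns only the prompt. A hedged sketch of exercising it with the official client as configured in the tests (the expected result is inferred from the code above, not verified here):

# inside an async test, with `client` as in tests/entrypoints/test_openai_server.py
completion = await client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="Hello, my name is",
    max_tokens=0,
    echo=True,
)
assert completion.choices[0].text == "Hello, my name is"  # prompt echoed, nothing generated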
|
130
vllm/entrypoints/openai/serving_engine.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import asyncio
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
||||
ChatCompletionRequest,
|
||||
ErrorResponse, LogProbs,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServing:
|
||||
|
||||
def __init__(self, engine: AsyncLLMEngine, served_model: str):
|
||||
self.engine = engine
|
||||
self.served_model = served_model
|
||||
|
||||
self.max_model_len = 0
|
||||
self.tokenizer = None
|
||||
|
||||
try:
|
||||
event_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
event_loop = None
|
||||
|
||||
if event_loop is not None and event_loop.is_running(
|
||||
): # If this is instantiated under Ray Serve, there is already a running event loop, so schedule _post_init on it
|
||||
event_loop.create_task(self._post_init())
|
||||
else: # When using single vLLM without engine_use_ray
|
||||
asyncio.run(self._post_init())
|
||||
|
||||
async def _post_init(self):
|
||||
engine_model_config = await self.engine.get_model_config()
|
||||
self.max_model_len = engine_model_config.max_model_len
|
||||
|
||||
# A separate tokenizer to map token IDs to strings.
|
||||
self.tokenizer = get_tokenizer(
|
||||
engine_model_config.tokenizer,
|
||||
tokenizer_mode=engine_model_config.tokenizer_mode,
|
||||
trust_remote_code=engine_model_config.trust_remote_code)
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. Right now we only have one model."""
|
||||
model_cards = [
|
||||
ModelCard(id=self.served_model,
|
||||
root=self.served_model,
|
||||
permission=[ModelPermission()])
|
||||
]
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
def _create_logprobs(
|
||||
self,
|
||||
token_ids: List[int],
|
||||
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
|
||||
num_output_top_logprobs: Optional[int] = None,
|
||||
initial_text_offset: int = 0,
|
||||
) -> LogProbs:
|
||||
"""Create OpenAI-style logprobs."""
|
||||
logprobs = LogProbs()
|
||||
last_token_len = 0
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs = []
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is not None:
|
||||
token_logprob = step_top_logprobs[token_id]
|
||||
else:
|
||||
token_logprob = None
|
||||
token = self.tokenizer.convert_ids_to_tokens(token_id)
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(token_logprob)
|
||||
if len(logprobs.text_offset) == 0:
|
||||
logprobs.text_offset.append(initial_text_offset)
|
||||
else:
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] +
|
||||
last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs.append({
|
||||
self.tokenizer.convert_ids_to_tokens(i): p
|
||||
for i, p in step_top_logprobs.items()
|
||||
} if step_top_logprobs else None)
|
||||
return logprobs
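# Illustrative result (token strings and values made up): for two sampled tokens
# this builds a LogProbs object along the lines of
#   tokens=["Hello", ","], token_logprobs=[-0.4, -1.2], text_offset=[0, 5],
#   top_logprobs=[{"Hello": -0.4, "Hi": -1.1}, None]
# i.e. the OpenAI-style per-token fields that the Completion responses embed.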
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
|
||||
return ErrorResponse(message=message,
|
||||
type=err_type,
|
||||
code=status_code.value)
|
||||
|
||||
async def _check_model(self, request) -> Optional[ErrorResponse]:
|
||||
if request.model == self.served_model:
|
||||
return
|
||||
return self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
async def _check_length(
|
||||
self,
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
prompt: Optional[str] = None,
|
||||
prompt_ids: Optional[List[int]] = None
|
||||
) -> Tuple[List[int], Optional[ErrorResponse]]:
|
||||
assert (not (prompt is None and prompt_ids is None)
|
||||
and not (prompt is not None and prompt_ids is not None)
|
||||
), "Exactly one of prompt or prompt_ids must be provided."
|
||||
input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
|
||||
prompt).input_ids
|
||||
token_num = len(input_ids)
|
||||
|
||||
if request.max_tokens is None:
|
||||
request.max_tokens = self.max_model_len - token_num
|
||||
if token_num + request.max_tokens > self.max_model_len:
|
||||
return input_ids, self.create_error_response(
|
||||
f"This model's maximum context length is {self.max_model_len} tokens. "
|
||||
f"However, you requested {request.max_tokens + token_num} tokens "
|
||||
f"({token_num} in the messages, "
|
||||
f"{request.max_tokens} in the completion). "
|
||||
f"Please reduce the length of the messages or completion.", )
|
||||
else:
|
||||
return input_ids, None
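A worked example of the length check above (numbers illustrative): with max_model_len = 8192 and a 6-token prompt, a request that omits max_tokens is defaulted to 8192 - 6 = 8186; a request asking for max_tokens = 8200 is rejected with the context-length error, since 6 + 8200 > 8192.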
|