# vllm/cacheflow/entrypoints/openai/openai_frontend.py

# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
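"""An OpenAI-compatible REST frontend for the cacheflow LLM server.

The app exposes `GET /v1/models` and `POST /v1/completions` and forwards
generation requests to an `AsyncLLMServer`. A minimal usage sketch (the model
name below is illustrative; the exact server flags come from `AsyncServerArgs`):

    $ python openai_frontend.py --model facebook/opt-125m --port 8000
    $ curl http://localhost:8000/v1/completions \
        -H "Content-Type: application/json" \
        -d '{"model": "facebook/opt-125m", "prompt": "Hello,", "max_tokens": 16}'
"""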
import argparse
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional
import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
LogProbs,
ModelCard,
ModelList,
ModelPermission,
UsageInfo,
)

TIMEOUT_KEEP_ALIVE = 5  # seconds

logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()


def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(
ErrorResponse(message=message, type="invalid_request_error").dict(),
status_code=status_code.value
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))


async def check_model(request) -> Optional[JSONResponse]:
if request.model == served_model:
return
ret = create_error_response(
HTTPStatus.NOT_FOUND,
f"The model `{request.model}` does not exist.",
)
    return ret


@app.get("/v1/models")
async def show_available_models():
"""Show available models. Right now we only have one model."""
model_cards = [ModelCard(id=served_model, root=served_model,
permission=[ModelPermission()])]
return ModelList(data=model_cards)
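# `create_logprobs` builds the OpenAI-style `logprobs` object: parallel lists
# over the generated tokens. For example (values are illustrative only):
#   tokens:         ["Hello", " world"]
#   token_logprobs: [-0.3, -1.2]
#   text_offset:    [0, 5]    (character offset of each token in the text)
#   top_logprobs:   [{"Hello": -0.3, "Hi": -1.5}, {" world": -1.2, "!": -2.0}]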
def create_logprobs(token_ids: List[int],
id_logprobs: List[Dict[int, float]],
initial_text_offset: int = 0) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
for token_id, id_logprob in zip(token_ids, id_logprobs):
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(id_logprob[token_id])
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
last_token_len = len(token)
logprobs.top_logprobs.append(
{tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()})
    return logprobs


@app.post("/v1/completions")
async def create_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
- echo (since the cacheflow server does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported in cacheflow server)
"""
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.echo:
# We do not support echo since the cacheflow server does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")
if request.logit_bias is not None:
# TODO: support logit_bias in cacheflow server.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
prompt = request.prompt
created_time = int(time.time())
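    # Map the OpenAI request fields onto cacheflow's SamplingParams. Note that
    # top_k, ignore_eos, and use_beam_search are cacheflow-specific extensions
    # rather than official OpenAI Completion API parameters.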
try:
sampling_params = SamplingParams(
n=request.n,
best_of=request.best_of,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = server.generate(prompt, sampling_params,
request_id)
    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results, since the final n outputs are selected from the best_of
    # candidates only after generation completes. Likewise, we do not stream
    # the results when using beam search.
    stream = (request.stream and
              (request.best_of is None or request.n == request.best_of) and
              not request.use_beam_search)

    async def abort_request() -> None:
        await server.abort(request_id)

    def create_stream_response_json(index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)
        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
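        # Emit Server-Sent Events: each chunk is a JSON payload prefixed with
        # "data: " and terminated by a blank line; the stream ends with
        # "data: [DONE]", mirroring the OpenAI streaming format.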
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
if request.logprobs is not None:
logprobs = create_logprobs(
output.token_ids[previous_num_tokens[i]:],
output.logprobs[previous_num_tokens[i]:],
len(previous_texts[i]))
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
logprobs=logprobs,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None
response_json = create_stream_response_json(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if stream:
background_tasks = BackgroundTasks()
        # The background task runs once the streaming response finishes;
        # aborting here stops generation on the server if the client
        # disconnected mid-stream.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
# Non-streaming response
    final_res: Optional[RequestOutput] = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await server.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
if request.logprobs is not None:
logprobs = create_logprobs(output.token_ids, output.logprobs)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=output.index,
text=output.text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids)
for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
        # When the user requests streaming but we do not stream, we still
        # need to return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
    return response


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="CacheFlow OpenAI-Compatible RESTful API server."
)
parser.add_argument("--host", type=str, default="localhost", help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--allow-credentials", action="store_true", help="allow credentials"
)
parser.add_argument(
"--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
)
parser.add_argument(
"--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
)
parser.add_argument(
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
)
parser.add_argument("--served-model-name", type=str, default=None,
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
parser = AsyncServerArgs.add_cli_args(parser)
args = parser.parse_args()
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
logger.info(f"args: {args}")
served_model = args.served_model_name or args.model
server_args = AsyncServerArgs.from_cli_args(args)
server = AsyncLLMServer.from_server_args(server_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(args.model)
uvicorn.run(app, host=args.host, port=args.port, log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)