# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

import argparse
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional

import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    ErrorResponse,
    LogProbs,
    ModelCard,
    ModelList,
    ModelPermission,
    UsageInfo,
)

TIMEOUT_KEEP_ALIVE = 5  # seconds

logger = init_logger(__name__)
# Assigned in the ``__main__`` block from --served-model-name (or --model).
served_model = None
app = fastapi.FastAPI()


def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    """Build an OpenAI-style ``invalid_request_error`` JSON response.

    Args:
        status_code: HTTP status to send.
        message: Human-readable error message placed in the body.
    """
    return JSONResponse(
        ErrorResponse(message=message, type="invalid_request_error").dict(),
        status_code=status_code.value)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Report FastAPI request-validation failures as 400 errors."""
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))


async def check_model(request) -> Optional[JSONResponse]:
    """Check that the request targets the model this server serves.

    Returns:
        ``None`` if ``request.model`` equals ``served_model``; otherwise a
        404 error response naming the unknown model.
    """
    if request.model == served_model:
        return None
    ret = create_error_response(
        HTTPStatus.NOT_FOUND,
        f"The model `{request.model}` does not exist.",
    )
    return ret


@app.get("/v1/models")
async def show_available_models():
    """Show available models. Right now we only have one model."""
    model_cards = [
        ModelCard(id=served_model,
                  root=served_model,
                  permission=[ModelPermission()])
    ]
    return ModelList(data=model_cards)


def create_logprobs(token_ids: List[int],
                    id_logprobs: List[Dict[int, float]],
                    initial_text_offset: int = 0) -> LogProbs:
    """Create OpenAI-style logprobs.

    Uses the module-level ``tokenizer`` (assigned in the ``__main__`` block)
    to map token IDs to token strings.

    Args:
        token_ids: Generated token IDs.
        id_logprobs: One dict per token, mapping candidate token IDs to their
            log-probabilities. Each dict must contain the corresponding
            generated token ID (it is indexed directly).
        initial_text_offset: Text offset recorded for the first token.

    Returns:
        A ``LogProbs`` with parallel ``tokens``, ``token_logprobs``,
        ``text_offset`` and ``top_logprobs`` entries, one per token.
    """
    logprobs = LogProbs()
    last_token_len = 0
    for token_id, id_logprob in zip(token_ids, id_logprobs):
        token = tokenizer.convert_ids_to_tokens(token_id)
        logprobs.tokens.append(token)
        logprobs.token_logprobs.append(id_logprob[token_id])
        if len(logprobs.text_offset) == 0:
            logprobs.text_offset.append(initial_text_offset)
        else:
            logprobs.text_offset.append(logprobs.text_offset[-1] +
                                        last_token_len)
        # NOTE(review): this is the length of the raw token string from
        # convert_ids_to_tokens, which may differ from the length of the
        # decoded text for that token — confirm offsets are as intended.
        last_token_len = len(token)

        logprobs.top_logprobs.append(
            {tokenizer.convert_ids_to_tokens(i): p
             for i, p in id_logprob.items()})
    return logprobs


@app.post("/v1/completions")
async def create_completion(raw_request: Request):
    """Completion endpoint modeled on OpenAI's ``/v1/completions``.

    Requests using unsupported features (``echo``, ``suffix``,
    ``logit_bias``), an unknown model, invalid sampling parameters, or a
    malformed body are rejected with an error response. Results are streamed
    as server-sent events only when ``stream`` is set, ``best_of`` is unset or
    equal to ``n``, and beam search is off; otherwise generation finishes
    first (and, if ``stream`` was requested, the full result is sent as a
    single event).
    """
    try:
        request = CompletionRequest(**await raw_request.json())
    except (ValueError, TypeError) as e:
        # The body is parsed by hand rather than through a typed FastAPI
        # parameter, so RequestValidationError (and its handler above) never
        # fires here. Invalid JSON (JSONDecodeError), a non-object body
        # (TypeError from **), or a failed field validation (pydantic's
        # ValidationError is a ValueError) would otherwise escape as a 500.
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
    logger.info(f"Received completion request: {request}")

    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    if request.echo:
        # We do not support echo since the cacheflow server does not
        # currently support getting the logprobs of prompt tokens.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "echo is not currently supported")

    if request.suffix is not None:
        # The language models we currently support do not support suffix.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "suffix is not currently supported")

    if request.logit_bias is not None:
        # TODO: support logit_bias in cacheflow server.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "logit_bias is not currently supported")

    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"
    prompt = request.prompt
    created_time = int(time.time())
    try:
        sampling_params = SamplingParams(
            n=request.n,
            best_of=request.best_of,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            stop=request.stop,
            ignore_eos=request.ignore_eos,
            max_tokens=request.max_tokens,
            logprobs=request.logprobs,
            use_beam_search=request.use_beam_search,
        )
    except ValueError as e:
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

    result_generator = server.generate(prompt, sampling_params, request_id)

    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. In addition, we do not stream the results when use beam search.
    stream = (request.stream
              and (request.best_of is None or request.n == request.best_of)
              and not request.use_beam_search)

    async def abort_request() -> None:
        """Ask the server to abort this request."""
        await server.abort(request_id)

    def create_stream_response_json(index: int,
                                    text: str,
                                    logprobs: Optional[LogProbs] = None,
                                    finish_reason: Optional[str] = None) -> str:
        """Serialize one streamed choice chunk as a JSON string."""
        choice_data = CompletionResponseStreamChoice(
            index=index,
            text=text,
            logprobs=logprobs,
            finish_reason=finish_reason,
        )
        response = CompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
        )
        response_json = response.json(ensure_ascii=False)

        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        """Yield SSE ``data:`` events carrying only newly generated text."""
        # Per-choice text and token counts already sent, so each event
        # carries only the delta since the previous one.
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        async for res in result_generator:
            res: RequestOutput
            for output in res.outputs:
                i = output.index
                delta_text = output.text[len(previous_texts[i]):]
                if request.logprobs is not None:
                    logprobs = create_logprobs(
                        output.token_ids[previous_num_tokens[i]:],
                        output.logprobs[previous_num_tokens[i]:],
                        len(previous_texts[i]))
                else:
                    logprobs = None
                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)
                response_json = create_stream_response_json(
                    index=i,
                    text=delta_text,
                    logprobs=logprobs,
                )
                yield f"data: {response_json}\n\n"
                if output.finish_reason is not None:
                    # Emit a separate, empty-text chunk carrying the
                    # finish reason for this choice.
                    logprobs = (LogProbs()
                                if request.logprobs is not None else None)
                    response_json = create_stream_response_json(
                        index=i,
                        text="",
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                    )
                    yield f"data: {response_json}\n\n"
        yield "data: [DONE]\n\n"

    # Streaming response
    if stream:
        background_tasks = BackgroundTasks()
        # Abort the request if the client disconnects.
        background_tasks.add_task(abort_request)
        return StreamingResponse(completion_stream_generator(),
                                 media_type="text/event-stream",
                                 background=background_tasks)

    # Non-streaming response
    final_res: Optional[RequestOutput] = None
    async for res in result_generator:
        if await raw_request.is_disconnected():
            # Abort the request if the client disconnects.
            await server.abort(request_id)
            return create_error_response(HTTPStatus.BAD_REQUEST,
                                         "Client disconnected")
        final_res = res
    assert final_res is not None
    choices = []
    for output in final_res.outputs:
        if request.logprobs is not None:
            logprobs = create_logprobs(output.token_ids, output.logprobs)
        else:
            logprobs = None
        choice_data = CompletionResponseChoice(
            index=output.index,
            text=output.text,
            logprobs=logprobs,
            finish_reason=output.finish_reason,
        )
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
    num_generated_tokens = sum(
        len(output.token_ids) for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    response = CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )

    if request.stream:
        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        response_json = response.json(ensure_ascii=False)

        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(fake_stream_generator(),
                                 media_type="text/event-stream")

    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="CacheFlow OpenAI-Compatible RESTful API server.")
    parser.add_argument("--host",
                        type=str,
                        default="localhost",
                        help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument("--allow-credentials",
                        action="store_true",
                        help="allow credentials")
    parser.add_argument("--allowed-origins",
                        type=json.loads,
                        default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods",
                        type=json.loads,
                        default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers",
                        type=json.loads,
                        default=["*"],
                        help="allowed headers")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. If not specified, "
                        "the model name will be the same as the "
                        "huggingface name.")
    parser = AsyncServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    logger.info(f"args: {args}")

    served_model = args.served_model_name or args.model

    server_args = AsyncServerArgs.from_cli_args(args)
    server = AsyncLLMServer.from_server_args(server_args)

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(args.model)

    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)