OpenAI Compatible Frontend (#116)

Zhuohan Li 2023-05-23 21:39:50 -07:00 committed by GitHub
parent e86717833d
commit 057daef778
20 changed files with 644 additions and 169 deletions

View File

@@ -148,7 +148,7 @@ class BlockSpaceManager:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ class BlockSpaceManager:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ class BlockSpaceManager:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]

View File

@@ -292,10 +292,12 @@ class Scheduler:
             # Append a new token to the sequence.
             output = seq_outputs[seq.seq_id]
             seq.append_token_id(output.output_token, output.logprobs)
+        # Return a shallow copy of the running queue to prevent the queue
+        # from being modified by the caller.
         return self.running.copy()

-    def free_seq(self, seq: Sequence) -> None:
-        seq.status = SequenceStatus.FINISHED
+    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
+        seq.status = finish_status
         self.block_manager.free(seq)

     def free_finished_seq_groups(self) -> None:

View File

@@ -0,0 +1,300 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
import argparse
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional
import fastapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
LogProbs,
ModelCard,
ModelList,
ModelPermission,
UsageInfo,
)
logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()
def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(
ErrorResponse(message=message, type="invalid_request_error").dict(),
status_code=status_code.value
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
async def check_model(request) -> Optional[JSONResponse]:
if request.model == served_model:
return
ret = create_error_response(
HTTPStatus.NOT_FOUND,
f"The model `{request.model}` does not exist.",
)
return ret
@app.get("/v1/models")
async def show_available_models():
"""Show available models. Right now we only have one model."""
model_cards = [ModelCard(id=served_model, root=served_model,
permission=[ModelPermission()])]
return ModelList(data=model_cards)
def create_logprobs(token_ids: List[int],
id_logprobs: List[Dict[int, float]],
initial_text_offset: int = 0) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
for token_id, id_logprob in zip(token_ids, id_logprobs):
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(id_logprob[token_id])
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
last_token_len = len(token)
logprobs.top_logprobs.append(
{tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()})
return logprobs
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
logger.info(f"Received completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.echo:
# We do not support echo since the cacheflow server does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")
if request.logit_bias is not None:
# TODO: support logit_bias in cacheflow server.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
prompt = request.prompt
created_time = int(time.time())
try:
sampling_params = SamplingParams(
n=request.n,
best_of=request.best_of,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = server.generate(prompt, sampling_params,
request_id=request_id)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
stream = (request.stream and
(request.best_of is None or request.n == request.best_of) and
not request.use_beam_search)
def create_stream_response_json(index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
if request.logprobs is not None:
logprobs = create_logprobs(
output.token_ids[previous_num_tokens[i]:],
output.logprobs[previous_num_tokens[i]:],
len(previous_texts[i]))
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
logprobs=logprobs,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None
response_json = create_stream_response_json(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if stream:
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream")
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
if request.logprobs is not None:
logprobs = create_logprobs(output.token_ids, output.logprobs)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=output.index,
text=output.text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids)
for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="CacheFlow OpenAI-Compatible RESTful API server."
)
parser.add_argument("--host", type=str, default="localhost", help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--allow-credentials", action="store_true", help="allow credentials"
)
parser.add_argument(
"--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
)
parser.add_argument(
"--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
)
parser.add_argument(
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
)
parser.add_argument("--served-model-name", type=str, default=None,
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
parser = ServerArgs.add_cli_args(parser)
args = parser.parse_args()
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
logger.info(f"args: {args}")
served_model = args.served_model_name or args.model
server_args = ServerArgs.from_cli_args(args)
server = AsyncLLMServer.from_server_args(server_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(args.model)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
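
The completions endpoint above streams results as server-sent events: one "data: {...}" line per chunk, terminated by "data: [DONE]". Below is a minimal client sketch (not part of this commit) that consumes that stream; it assumes the server is running locally on port 8000 and serving facebook/opt-125m, and uses only the requests library.

import json
import requests

def stream_completion(prompt: str,
                      api_url: str = "http://localhost:8000/v1/completions") -> None:
    payload = {
        "model": "facebook/opt-125m",  # must match the served model name
        "prompt": prompt,
        "max_tokens": 16,
        "stream": True,
    }
    with requests.post(api_url, json=payload, stream=True) as response:
        for line in response.iter_lines(decode_unicode=True):
            # Skip keep-alive blanks; only "data: ..." lines carry payloads.
            if not line or not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            chunk = json.loads(data)
            # Each chunk carries the newly generated text delta per choice.
            print(chunk["choices"][0]["text"], end="", flush=True)
    print()

if __name__ == "__main__":
    stream_completion("A robot may not injure a human being")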

View File

@@ -0,0 +1,126 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from cacheflow.utils import random_uuid
class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: Optional[str] = None
class ModelPermission(BaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
allow_create_engine: bool = False
allow_sampling: bool = True
allow_logprobs: bool = True
allow_search_indices: bool = False
allow_view: bool = True
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
is_blocking: str = False
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "cacheflow"
root: Optional[str] = None
parent: Optional[str] = None
permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Dict[str, str]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
class CompletionRequest(BaseModel):
model: str
prompt: str
suffix: Optional[str] = None
max_tokens: Optional[int] = 16
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
n: Optional[int] = 1
stream: Optional[bool] = False
logprobs: Optional[int] = None
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by cacheflow
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
class CompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
class CompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
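
A quick sketch of how these pydantic models are meant to be used, mirroring the completions handler above: parse a raw JSON body into a CompletionRequest, then map the validated fields onto SamplingParams. The concrete field values here are illustrative only.

from cacheflow.entrypoints.openai.protocol import CompletionRequest
from cacheflow.sampling_params import SamplingParams

raw_body = {
    "model": "facebook/opt-125m",
    "prompt": "Hello, my name is",
    "max_tokens": 32,
    "logprobs": 3,
    "top_k": 50,  # cacheflow-specific extension, not part of the OpenAI API
}
# Pydantic validates the field types and fills in the defaults declared above.
request = CompletionRequest(**raw_body)
sampling_params = SamplingParams(
    n=request.n,
    temperature=request.temperature,
    top_p=request.top_p,
    top_k=request.top_k,
    max_tokens=request.max_tokens,
    logprobs=request.logprobs,
)
print(request.json())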

View File

@@ -0,0 +1,51 @@
import argparse
import json
from typing import AsyncGenerator
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.ray_utils import initialize_cluster
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
app = FastAPI()
@app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse:
request_dict = await request.json()
prompt = request_dict.pop("prompt")
sampling_params = SamplingParams(**request_dict)
results_generator = server.generate(prompt, sampling_params)
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [
prompt + output.text
for output in request_output.outputs
]
ret = {
"text": text_outputs,
"error": 0,
}
yield (json.dumps(ret) + "\0").encode("utf-8")
return StreamingResponse(stream_results())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser = ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
server = AsyncLLMServer.from_server_args(server_args)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")

View File

@@ -1,5 +1,5 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional

 import numpy as np
 import torch
@@ -258,9 +258,9 @@ def _apply_top_p_top_k(

 def _get_topk_logprobs(
     logprobs: torch.Tensor,
-    num_logprobs: int,
+    num_logprobs: Optional[int],
 ) -> Dict[int, float]:
-    if num_logprobs == 0:
+    if num_logprobs is None or num_logprobs == 0:
         return {}
     topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
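
For reference, a standalone toy version of the new behavior (not part of the commit): num_logprobs=None now means "logprobs were not requested", while 0 still yields an empty dict.

import torch

def get_topk_logprobs(logprobs: torch.Tensor, num_logprobs):
    if num_logprobs is None or num_logprobs == 0:
        return {}
    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
    return {token_id.item(): logprob.item()
            for token_id, logprob in zip(topk_ids, topk_logprobs)}

vocab_logprobs = torch.log_softmax(torch.randn(8), dim=-1)
print(get_topk_logprobs(vocab_logprobs, None))  # {}
print(get_topk_logprobs(vocab_logprobs, 2))     # two highest-probability ids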

View File

@@ -1,6 +1,6 @@
-from typing import Dict, List
+from typing import Dict, List, Optional

-from cacheflow.sequence import SequenceGroup
+from cacheflow.sequence import SequenceGroup, SequenceStatus


 class CompletionOutput:
@@ -12,19 +12,25 @@ class CompletionOutput:
         token_ids: List[int],
         cumulative_logprob: float,
         logprobs: List[Dict[int, float]],
+        finish_reason: Optional[str] = None,
     ) -> None:
         self.index = index
         self.text = text
         self.token_ids = token_ids
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
+        self.finish_reason = finish_reason
+
+    def finished(self) -> bool:
+        return self.finish_reason is not None

     def __repr__(self) -> str:
         return (f"CompletionOutput(index={self.index}, "
                 f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs},"
+                f"finish_reason={self.finish_reason})")


 class RequestOutput:
@@ -35,13 +41,11 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
         self.outputs = outputs
-        self.done = done

     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
@@ -57,25 +61,28 @@ class RequestOutput:
         outputs: List[CompletionOutput] = []
         for seq in top_n_seqs:
             logprobs = seq.output_logprobs
-            if seq_group.sampling_params.logprobs == 0:
+            if seq_group.sampling_params.logprobs is None:
                 # NOTE: We need to take care of this case because the sequence
                 # always has the logprobs of the sampled tokens even if the
                 # logprobs are not requested.
                 logprobs = {}
+            finshed_reason = SequenceStatus.get_finished_reason(seq.status)
             output = CompletionOutput(seqs.index(seq), seq.output_text,
                                       seq.get_output_token_ids(),
-                                      seq.get_cumulative_logprob(), logprobs)
+                                      seq.get_cumulative_logprob(), logprobs,
+                                      finshed_reason)
             outputs.append(output)

         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
-                   seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs)

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
                 f"prompt={self.prompt!r}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
-                f"outputs={self.outputs}, "
-                f"done={self.done})")
+                f"outputs={self.outputs})")
+
+    def finished(self) -> bool:
+        return all(output.finished() for output in self.outputs)
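
With the `done` flag gone, callers poll RequestOutput.finished() instead. A minimal sketch of that pattern (the same loop appears in examples/simple_server.py later in this commit):

from cacheflow import LLMServer

def drain_requests(server: LLMServer) -> None:
    # Step the server until every in-flight request has finished, printing
    # each completed output together with its finish reason ("stop"/"length").
    while server.has_unfinished_requests():
        for request_output in server.step():
            if request_output.finished():
                for output in request_output.outputs:
                    print(output.index, output.finish_reason, repr(output.text))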

View File

@@ -53,7 +53,7 @@ class SamplingParams:
         stop: Union[str, List[str]] = [],
         ignore_eos: bool = False,
         max_tokens: int = 16,
-        logprobs: int = 0,
+        logprobs: Optional[int] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -98,7 +98,7 @@ class SamplingParams:
         if self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
-        if self.logprobs < 0:
+        if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")

View File

@@ -1,6 +1,6 @@
 import copy
 import enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 from cacheflow.block import LogicalTokenBlock
 from cacheflow.sampling_params import SamplingParams
@@ -10,8 +10,25 @@ class SequenceStatus(enum.Enum):
     WAITING = enum.auto()
     RUNNING = enum.auto()
     SWAPPED = enum.auto()
-    FINISHED = enum.auto()
+    FINISHED_STOPPED = enum.auto()
+    FINISHED_LENGTH_CAPPED = enum.auto()
+
+    @staticmethod
+    def is_finished(status: "SequenceStatus") -> bool:
+        return status in [
+            SequenceStatus.FINISHED_STOPPED,
+            SequenceStatus.FINISHED_LENGTH_CAPPED,
+        ]
+
+    @staticmethod
+    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
+        if status == SequenceStatus.FINISHED_STOPPED:
+            finish_reason = "stop"
+        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
+            finish_reason = "length"
+        else:
+            finish_reason = None
+        return finish_reason


 class SequenceData:
@@ -20,7 +37,6 @@ class SequenceData:
         prompt_token_ids: List[int],
     ) -> None:
         self.prompt_token_ids = prompt_token_ids
-
         self.output_token_ids: List[int] = []
         self.cumulative_logprob = 0.0
@@ -166,7 +182,7 @@ class SequenceGroup:
             raise ValueError(f'Sequence {seq_id} not found.')

     def is_finished(self) -> bool:
-        return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)
+        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
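
A small illustration of the new helpers (standalone sketch, assuming only the enum above):

from cacheflow.sequence import SequenceStatus

# FINISHED_STOPPED maps to "stop", FINISHED_LENGTH_CAPPED maps to "length",
# and the non-finished statuses map to None.
for status in SequenceStatus:
    print(status.name,
          SequenceStatus.is_finished(status),
          SequenceStatus.get_finished_reason(status))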

View File

@@ -1,26 +1,20 @@
-import argparse
 import asyncio
-import json
 import time
-from typing import Any, Dict
-import uuid
+from typing import Dict, Optional

-from fastapi import FastAPI, Request
-from fastapi.responses import StreamingResponse
 import ray
-import uvicorn

 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.utils import random_uuid

 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds

-app = FastAPI()

-class FastAPIServer:
+class AsyncLLMServer:
     def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
         if server_use_ray:
@@ -45,15 +39,15 @@ class FastAPIServer:
         self.request_outputs[request_id] = request_output
         self.request_events[request_id].set()

-    async def generate(self, request_dict: Dict[str, Any]):
+    async def generate(self, prompt: str, sampling_params: SamplingParams,
+                       request_id: Optional[str] = None) -> RequestOutput:
         # Preprocess the request.
         arrival_time = time.time()
-        prompt = request_dict.pop("prompt")
-        sampling_params = SamplingParams(**request_dict)

         # Create an event to notify us that there is new output from the
         # cacheflow server.
-        request_id = str(uuid.uuid4().hex[:8])
+        if request_id is None:
+            request_id = random_uuid()
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event
@@ -82,19 +76,10 @@ class FastAPIServer:
             # Decode and return new outputs.
             request_output = self.request_outputs[request_id]
-            prompt = request_output.prompt
-            text_outputs = [
-                prompt + output.text
-                for output in request_output.outputs
-            ]
-            ret = {
-                "text": text_outputs,
-                "error": 0,
-            }
-            yield (json.dumps(ret) + "\0").encode("utf-8")
+            yield request_output

             # Once finished, release the resources of the sequence group.
-            if request_output.done:
+            if request_output.finished():
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
                 # Kick the server if the server is not running. This is to
@@ -104,25 +89,15 @@ class FastAPIServer:
                 await self.server_step()
             break

-
-@app.post("/generate")
-async def generate_stream(request: Request):
-    request_dict = await request.json()
-    return StreamingResponse(server.generate(request_dict))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=10002)
-    parser = ServerArgs.add_cli_args(parser)
-    args = parser.parse_args()
-    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
-    parallel_config = server_configs[2]
-    distributed_init_method, stage_devices = initialize_cluster(parallel_config)
-    server = FastAPIServer(args.use_ray, *server_configs,
-                           distributed_init_method, stage_devices,
-                           log_stats=not args.disable_log_stats)
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(server_args.use_ray, *server_configs,
+                     distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
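
For reference, a hedged sketch of driving AsyncLLMServer directly, in the same way the two FastAPI entrypoints in this commit do; the prompt and sampling values are illustrative.

import argparse
import asyncio

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer

async def main() -> None:
    parser = argparse.ArgumentParser()
    parser = ServerArgs.add_cli_args(parser)
    args = parser.parse_args()
    server = AsyncLLMServer.from_server_args(ServerArgs.from_cli_args(args))

    sampling_params = SamplingParams(max_tokens=32)
    # generate() is an async generator: it yields a RequestOutput every time
    # new tokens become available for this request.
    async for request_output in server.generate("Hello, my name is",
                                                sampling_params):
        print(request_output.outputs[0].text)

if __name__ == "__main__":
    asyncio.run(main())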

View File

@@ -210,7 +210,8 @@ class LLMServer:
                     # Truncate the output text so that the stop string is
                     # not included in the output.
                     seq.output_text = seq.output_text[:-len(stop_str)]
-                    self.scheduler.free_seq(seq)
+                    self.scheduler.free_seq(seq,
+                                            SequenceStatus.FINISHED_STOPPED)
                     stopped = True
                     break
             if stopped:
@@ -218,12 +219,14 @@ class LLMServer:

             # Check if the sequence has reached max_tokens.
             if seq.get_output_len() == sampling_params.max_tokens:
-                self.scheduler.free_seq(seq)
+                self.scheduler.free_seq(
+                    seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                 continue

             # Check if the sequence has generated the EOS token.
             if not sampling_params.ignore_eos:
                 if seq.get_last_token_id() == self.tokenizer.eos_token_id:
-                    self.scheduler.free_seq(seq)
+                    self.scheduler.free_seq(seq,
+                                            SequenceStatus.FINISHED_STOPPED)
                     continue

     def _run_workers(

View File

@@ -1,4 +1,5 @@
 import enum
+import uuid

 import psutil
 import torch
@@ -31,3 +32,7 @@ def get_gpu_memory(gpu: int = 0) -> int:

 def get_cpu_memory() -> int:
     """Returns the total CPU memory of the node in bytes."""
     return psutil.virtual_memory().total
+
+
+def random_uuid() -> str:
+    return str(uuid.uuid4().hex)

View File

@@ -1,6 +1,5 @@
 import argparse
 import json
-import time

 import gradio as gr
 import requests
@@ -24,9 +23,9 @@ def http_bot(prompt):
 def build_demo():
     with gr.Blocks() as demo:
         gr.Markdown(
-            "# Cacheflow demo\n"
+            "# Cacheflow text completion demo\n"
         )
-        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")# .style(container=False)
+        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
         outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
         inputbox.submit(http_bot, [inputbox], [outputbox])
         return demo
@@ -35,9 +34,11 @@ def build_demo():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=10003)
-    parser.add_argument("--model-url", type=str, default="http://localhost:10002/generate")
+    parser.add_argument("--port", type=int, default=8002)
+    parser.add_argument("--model-url", type=str, default="http://localhost:8001/generate")
     args = parser.parse_args()

     demo = build_demo()
-    demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port)
+    demo.queue(concurrency_count=100).launch(server_name=args.host,
+                                             server_port=args.port,
+                                             share=True)

examples/openai_client.py (new file, 22 lines)
View File

@@ -0,0 +1,22 @@
import openai
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
model = "facebook/opt-125m"
# list models
models = openai.Model.list()
print(models)
# create a completion
stream = True
completion = openai.Completion.create(
model=model, prompt="A robot may not injure a human being", echo=False, n=2,
best_of=3, stream=stream, logprobs=3)
# print the completion
if stream:
for c in completion:
print(c)
else:
print("completion:", completion)

View File

@@ -0,0 +1,48 @@
import argparse
import requests
import json
def clear_line(n=1):
LINE_UP = '\033[1A'
LINE_CLEAR = '\x1b[2K'
for i in range(n):
print(LINE_UP, end=LINE_CLEAR, flush=True)
def http_request(prompt: str, api_url: str, n: int = 1):
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
"n": n,
"use_beam_search": True,
"temperature": 0.0,
"max_tokens": 16,
}
response = requests.post(api_url, headers=headers, json=pload, stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
yield output
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--n", type=int, default=4)
parser.add_argument("--prompt", type=str, default="San Francisco is a")
args = parser.parse_args()
prompt = args.prompt
api_url = f"http://{args.host}:{args.port}/generate"
n = args.n
print(f"Prompt: {prompt}\n", flush=True)
num_printed_lines = 0
for h in http_request(prompt, api_url, n):
clear_line(num_printed_lines)
num_printed_lines = 0
for i, line in enumerate(h):
num_printed_lines += 1
print(f"Beam candidate {i}: {line}", flush=True)

View File

@@ -1,5 +1,4 @@
 import argparse
-import uuid

 from cacheflow import ServerArgs, LLMServer, SamplingParams
@@ -20,17 +19,19 @@ def main(args: argparse.Namespace):
         SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]

+    request_id = 0
+
     # Run the server.
     while True:
         # To test iteration-level scheduling, we add one request at each step.
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
-            request_id = str(uuid.uuid4().hex[:8])
-            server.add_request(request_id, prompt, sampling_params)
+            server.add_request(str(request_id), prompt, sampling_params)
+            request_id += 1

         request_outputs = server.step()
         for request_output in request_outputs:
-            if request_output.done:
+            if request_output.finished():
                 print(request_output)

         if not (server.has_unfinished_requests() or test_prompts):

View File

@@ -1,20 +0,0 @@
import requests
import json
def http_bot():
prompt = "How are you? I'm fine."
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
}
response = requests.post("http://localhost:10002", headers=headers, json=pload, stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
yield output
for h in http_bot():
print(h, end="", flush=True)

View File

@@ -1,40 +0,0 @@
import argparse
import asyncio
import time
from typing import Union
import json
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn
app = FastAPI()
async def text_streamer(args):
context = args["prompt"]
words = context.split(" ")
for word in words:
await asyncio.sleep(1)
print("word:", word)
ret = {
"text": word + " ",
"error": 0,
}
yield (json.dumps(ret) + "\0").encode("utf-8")
@app.post("/")
async def read_root(request: Request):
args = await request.json()
return StreamingResponse(text_streamer(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=10002)
args = parser.parse_args()
uvicorn.run(app, host=args.host, port=args.port, log_level="info")

View File

@@ -8,3 +8,4 @@ transformers >= 4.28.0  # Required for LLaMA.
 xformers >= 0.0.19
 fastapi
 uvicorn
+pydantic  # Required for OpenAI server.

View File

@@ -1,23 +0,0 @@
import requests
import json
def http_request():
prompt = "Ion Stoica is a"
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
"n": 4,
"use_beam_search": True,
"temperature": 0.0,
}
response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
yield output
for h in http_request():
print(h, flush=True)