OpenAI Compatible Frontend (#116)

commit 057daef778 (parent e86717833d)
@@ -148,7 +148,7 @@ class BlockSpaceManager:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ class BlockSpaceManager:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ class BlockSpaceManager:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -292,10 +292,12 @@ class Scheduler:
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
                 seq.append_token_id(output.output_token, output.logprobs)
+        # Return a shallow copy of the running queue to prevent the queue
+        # from being modified by the caller.
         return self.running.copy()

-    def free_seq(self, seq: Sequence) -> None:
-        seq.status = SequenceStatus.FINISHED
+    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
+        seq.status = finish_status
         self.block_manager.free(seq)

     def free_finished_seq_groups(self) -> None:
new file: cacheflow/entrypoints/openai/openai_frontend.py (300 lines)

# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

import argparse
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional

import fastapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    ErrorResponse,
    LogProbs,
    ModelCard,
    ModelList,
    ModelPermission,
    UsageInfo,
)


logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()


def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    return JSONResponse(
        ErrorResponse(message=message, type="invalid_request_error").dict(),
        status_code=status_code.value
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))


async def check_model(request) -> Optional[JSONResponse]:
    if request.model == served_model:
        return
    ret = create_error_response(
        HTTPStatus.NOT_FOUND,
        f"The model `{request.model}` does not exist.",
    )
    return ret


@app.get("/v1/models")
async def show_available_models():
    """Show available models. Right now we only have one model."""
    model_cards = [ModelCard(id=served_model, root=served_model,
                             permission=[ModelPermission()])]
    return ModelList(data=model_cards)


def create_logprobs(token_ids: List[int],
                    id_logprobs: List[Dict[int, float]],
                    initial_text_offset: int = 0) -> LogProbs:
    """Create OpenAI-style logprobs."""
    logprobs = LogProbs()
    last_token_len = 0
    for token_id, id_logprob in zip(token_ids, id_logprobs):
        token = tokenizer.convert_ids_to_tokens(token_id)
        logprobs.tokens.append(token)
        logprobs.token_logprobs.append(id_logprob[token_id])
        if len(logprobs.text_offset) == 0:
            logprobs.text_offset.append(initial_text_offset)
        else:
            logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
        last_token_len = len(token)

        logprobs.top_logprobs.append(
            {tokenizer.convert_ids_to_tokens(i): p
             for i, p in id_logprob.items()})
    return logprobs


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    logger.info(f"Received completion request: {request}")

    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    if request.echo:
        # We do not support echo since the cacheflow server does not
        # currently support getting the logprobs of prompt tokens.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "echo is not currently supported")

    if request.suffix is not None:
        # The language models we currently support do not support suffix.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "suffix is not currently supported")

    if request.logit_bias is not None:
        # TODO: support logit_bias in cacheflow server.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "logit_bias is not currently supported")

    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"
    prompt = request.prompt
    created_time = int(time.time())
    try:
        sampling_params = SamplingParams(
            n=request.n,
            best_of=request.best_of,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            stop=request.stop,
            ignore_eos=request.ignore_eos,
            max_tokens=request.max_tokens,
            logprobs=request.logprobs,
            use_beam_search=request.use_beam_search,
        )
    except ValueError as e:
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

    result_generator = server.generate(prompt, sampling_params,
                                       request_id=request_id)

    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. In addition, we do not stream the results when use beam search.
    stream = (request.stream and
              (request.best_of is None or request.n == request.best_of) and
              not request.use_beam_search)

    def create_stream_response_json(index: int,
                                    text: str,
                                    logprobs: Optional[LogProbs] = None,
                                    finish_reason: Optional[str] = None) -> str:
        choice_data = CompletionResponseStreamChoice(
            index=index,
            text=text,
            logprobs=logprobs,
            finish_reason=finish_reason,
        )
        response = CompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
        )
        response_json = response.json(ensure_ascii=False)

        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        async for res in result_generator:
            res: RequestOutput
            for output in res.outputs:
                i = output.index
                delta_text = output.text[len(previous_texts[i]):]
                if request.logprobs is not None:
                    logprobs = create_logprobs(
                        output.token_ids[previous_num_tokens[i]:],
                        output.logprobs[previous_num_tokens[i]:],
                        len(previous_texts[i]))
                else:
                    logprobs = None
                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)
                response_json = create_stream_response_json(
                    index=i,
                    text=delta_text,
                    logprobs=logprobs,
                )
                yield f"data: {response_json}\n\n"
                if output.finish_reason is not None:
                    logprobs = LogProbs() if request.logprobs is not None else None
                    response_json = create_stream_response_json(
                        index=i,
                        text="",
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                    )
                    yield f"data: {response_json}\n\n"
        yield "data: [DONE]\n\n"

    # Streaming response
    if stream:
        return StreamingResponse(completion_stream_generator(),
                                 media_type="text/event-stream")

    # Non-streaming response
    final_res: RequestOutput = None
    async for res in result_generator:
        final_res = res
    assert final_res is not None
    choices = []
    for output in final_res.outputs:
        if request.logprobs is not None:
            logprobs = create_logprobs(output.token_ids, output.logprobs)
        else:
            logprobs = None
        choice_data = CompletionResponseChoice(
            index=output.index,
            text=output.text,
            logprobs=logprobs,
            finish_reason=output.finish_reason,
        )
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
    num_generated_tokens = sum(len(output.token_ids)
                               for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    response = CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )

    if request.stream:
        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        response_json = response.json(ensure_ascii=False)
        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"
        return StreamingResponse(fake_stream_generator(),
                                 media_type="text/event-stream")

    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="CacheFlow OpenAI-Compatible RESTful API server."
    )
    parser.add_argument("--host", type=str, default="localhost", help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--allow-credentials", action="store_true", help="allow credentials"
    )
    parser.add_argument(
        "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
    )
    parser.add_argument(
        "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
    )
    parser.add_argument(
        "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
    )
    parser.add_argument("--served-model-name", type=str, default=None,
                        help="The model name used in the API. If not specified, "
                             "the model name will be the same as the "
                             "huggingface name.")
    parser = ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    logger.info(f"args: {args}")

    served_model = args.served_model_name or args.model

    server_args = ServerArgs.from_cli_args(args)
    server = AsyncLLMServer.from_server_args(server_args)

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(args.model)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
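
Editor's note (not part of the commit): in the streaming path above, the endpoint emits server-sent events of the form "data: <CompletionStreamResponse JSON>" and terminates the stream with "data: [DONE]". A minimal client sketch that consumes this stream with requests, assuming the server is running locally on the default port 8000 and serving facebook/opt-125m (the model name used in examples/openai_client.py below):

import json
import requests

# Hypothetical client; URL, port, and model name are assumptions based on the
# defaults introduced in this commit, not an official example.
payload = {
    "model": "facebook/opt-125m",
    "prompt": "A robot may not injure a human being",
    "max_tokens": 16,
    "stream": True,
}
with requests.post("http://localhost:8000/v1/completions",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip the blank separator lines between events
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)  # mirrors CompletionStreamResponse above
        print(chunk["choices"][0]["text"], end="", flush=True)
print()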
new file: cacheflow/entrypoints/openai/protocol.py (126 lines)

# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field

from cacheflow.utils import random_uuid


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: str = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "cacheflow"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None


class CompletionRequest(BaseModel):
    model: str
    prompt: str
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters supported by cacheflow
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
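
Editor's note (illustrative sketch, not part of the commit): the pydantic models above carry OpenAI-compatible defaults, so a completion request only needs model and prompt; the values below are placeholders.

from cacheflow.entrypoints.openai.protocol import CompletionRequest

req = CompletionRequest(model="facebook/opt-125m", prompt="Hello")
assert req.max_tokens == 16 and req.temperature == 1.0   # OpenAI-style defaults
assert req.top_k == -1 and req.use_beam_search is False  # cacheflow extensions
print(req.dict())  # pydantic v1 serialization, as used by the frontend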
new file: cacheflow/entrypoints/simple_fastapi_frontend.py (51 lines)

import argparse
import json
from typing import AsyncGenerator

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.ray_utils import initialize_cluster

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
app = FastAPI()


@app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse:
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    sampling_params = SamplingParams(**request_dict)
    results_generator = server.generate(prompt, sampling_params)

    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            text_outputs = [
                prompt + output.text
                for output in request_output.outputs
            ]
            ret = {
                "text": text_outputs,
                "error": 0,
            }
            yield (json.dumps(ret) + "\0").encode("utf-8")

    return StreamingResponse(stream_results())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser = ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    server_args = ServerArgs.from_cli_args(args)
    server = AsyncLLMServer.from_server_args(server_args)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
@@ -1,5 +1,5 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional

 import numpy as np
 import torch

@@ -258,9 +258,9 @@ def _apply_top_p_top_k(

 def _get_topk_logprobs(
     logprobs: torch.Tensor,
-    num_logprobs: int,
+    num_logprobs: Optional[int],
 ) -> Dict[int, float]:
-    if num_logprobs == 0:
+    if num_logprobs is None or num_logprobs == 0:
         return {}

     topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
@@ -1,6 +1,6 @@
-from typing import Dict, List
+from typing import Dict, List, Optional

-from cacheflow.sequence import SequenceGroup
+from cacheflow.sequence import SequenceGroup, SequenceStatus


 class CompletionOutput:

@@ -12,19 +12,25 @@ class CompletionOutput:
         token_ids: List[int],
         cumulative_logprob: float,
         logprobs: List[Dict[int, float]],
+        finish_reason: Optional[str] = None,
     ) -> None:
         self.index = index
         self.text = text
         self.token_ids = token_ids
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
+        self.finish_reason = finish_reason
+
+    def finished(self) -> bool:
+        return self.finish_reason is not None

     def __repr__(self) -> str:
         return (f"CompletionOutput(index={self.index}, "
                 f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs},"
+                f"finish_reason={self.finish_reason})")


 class RequestOutput:

@@ -35,13 +41,11 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
         self.outputs = outputs
-        self.done = done

     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":

@@ -57,25 +61,28 @@ class RequestOutput:
         outputs: List[CompletionOutput] = []
         for seq in top_n_seqs:
             logprobs = seq.output_logprobs
-            if seq_group.sampling_params.logprobs == 0:
+            if seq_group.sampling_params.logprobs is None:
                 # NOTE: We need to take care of this case because the sequence
                 # always has the logprobs of the sampled tokens even if the
                 # logprobs are not requested.
                 logprobs = {}
+            finshed_reason = SequenceStatus.get_finished_reason(seq.status)
             output = CompletionOutput(seqs.index(seq), seq.output_text,
                                       seq.get_output_token_ids(),
-                                      seq.get_cumulative_logprob(), logprobs)
+                                      seq.get_cumulative_logprob(), logprobs,
+                                      finshed_reason)
             outputs.append(output)

         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
-                   seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs)

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
                 f"prompt={self.prompt!r}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
-                f"outputs={self.outputs}, "
-                f"done={self.done})")
+                f"outputs={self.outputs})")
+
+    def finished(self) -> bool:
+        return all(output.finished() for output in self.outputs)
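
Editor's note (sketch, not part of the commit): the request-level done flag is gone; completion is now derived from each output's finish_reason, for example:

from cacheflow.outputs import CompletionOutput, RequestOutput

# Made-up token IDs and logprobs, purely for illustration.
outs = [
    CompletionOutput(0, " Paris", [3681], -1.2, [], finish_reason="stop"),
    CompletionOutput(1, " Lyon", [49578], -3.4, [], finish_reason=None),
]
req = RequestOutput("cmpl-0", "The capital of France is", [464, 3139], outs)
assert not req.finished()          # the second candidate is still generating
outs[1].finish_reason = "length"   # e.g. it hit max_tokens
assert req.finished()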
@@ -53,7 +53,7 @@ class SamplingParams:
         stop: Union[str, List[str]] = [],
         ignore_eos: bool = False,
         max_tokens: int = 16,
-        logprobs: int = 0,
+        logprobs: Optional[int] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n

@@ -98,7 +98,7 @@ class SamplingParams:
         if self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
-        if self.logprobs < 0:
+        if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")
@@ -1,6 +1,6 @@
 import copy
 import enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 from cacheflow.block import LogicalTokenBlock
 from cacheflow.sampling_params import SamplingParams

@@ -10,8 +10,25 @@ class SequenceStatus(enum.Enum):
     WAITING = enum.auto()
     RUNNING = enum.auto()
     SWAPPED = enum.auto()
-    FINISHED = enum.auto()
+    FINISHED_STOPPED = enum.auto()
+    FINISHED_LENGTH_CAPPED = enum.auto()
+
+    @staticmethod
+    def is_finished(status: "SequenceStatus") -> bool:
+        return status in [
+            SequenceStatus.FINISHED_STOPPED,
+            SequenceStatus.FINISHED_LENGTH_CAPPED,
+        ]
+
+    @staticmethod
+    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
+        if status == SequenceStatus.FINISHED_STOPPED:
+            finish_reason = "stop"
+        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
+            finish_reason = "length"
+        else:
+            finish_reason = None
+        return finish_reason


 class SequenceData:

@@ -20,7 +37,6 @@ class SequenceData:
         prompt_token_ids: List[int],
     ) -> None:
         self.prompt_token_ids = prompt_token_ids
-
         self.output_token_ids: List[int] = []
         self.cumulative_logprob = 0.0

@@ -166,7 +182,7 @@ class SequenceGroup:
             raise ValueError(f'Sequence {seq_id} not found.')

     def is_finished(self) -> bool:
-        return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)
+        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
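
Editor's note (sketch, not part of the commit): splitting FINISHED into two terminal states is what lets the OpenAI frontend report a per-choice finish_reason:

from cacheflow.sequence import SequenceStatus

# Terminal states map onto the OpenAI finish_reason values.
assert SequenceStatus.get_finished_reason(SequenceStatus.FINISHED_STOPPED) == "stop"
assert SequenceStatus.get_finished_reason(SequenceStatus.FINISHED_LENGTH_CAPPED) == "length"
# Non-terminal states yield None and do not count as finished.
assert SequenceStatus.get_finished_reason(SequenceStatus.RUNNING) is None
assert not SequenceStatus.is_finished(SequenceStatus.RUNNING)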
@@ -1,26 +1,20 @@
-import argparse
 import asyncio
-import json
 import time
-from typing import Any, Dict
-import uuid
+from typing import Dict, Optional

-from fastapi import FastAPI, Request
-from fastapi.responses import StreamingResponse
 import ray
-import uvicorn

 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.utils import random_uuid

 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
-app = FastAPI()


-class FastAPIServer:
+class AsyncLLMServer:

     def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
         if server_use_ray:

@@ -45,15 +39,15 @@ class FastAPIServer:
         self.request_outputs[request_id] = request_output
         self.request_events[request_id].set()

-    async def generate(self, request_dict: Dict[str, Any]):
+    async def generate(self, prompt: str, sampling_params: SamplingParams,
+                       request_id: Optional[str] = None) -> RequestOutput:
         # Preprocess the request.
         arrival_time = time.time()
-        prompt = request_dict.pop("prompt")
-        sampling_params = SamplingParams(**request_dict)

         # Create an event to notify us that there is new output from the
         # cacheflow server.
-        request_id = str(uuid.uuid4().hex[:8])
+        if request_id is None:
+            request_id = random_uuid()
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event

@@ -82,19 +76,10 @@ class FastAPIServer:
             # Decode and return new outputs.
             request_output = self.request_outputs[request_id]
-            prompt = request_output.prompt
-            text_outputs = [
-                prompt + output.text
-                for output in request_output.outputs
-            ]
-            ret = {
-                "text": text_outputs,
-                "error": 0,
-            }
-            yield (json.dumps(ret) + "\0").encode("utf-8")
+            yield request_output

             # Once finished, release the resources of the sequence group.
-            if request_output.done:
+            if request_output.finished():
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
                 # Kick the server if the server is not running. This is to

@@ -104,25 +89,15 @@ class FastAPIServer:
                 await self.server_step()
             break

-@app.post("/generate")
-async def generate_stream(request: Request):
-    request_dict = await request.json()
-    return StreamingResponse(server.generate(request_dict))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=10002)
-    parser = ServerArgs.add_cli_args(parser)
-    args = parser.parse_args()
-
-    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
-    parallel_config = server_configs[2]
-    distributed_init_method, stage_devices = initialize_cluster(parallel_config)
-
-    server = FastAPIServer(args.use_ray, *server_configs,
-                           distributed_init_method, stage_devices,
-                           log_stats=not args.disable_log_stats)
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(server_args.use_ray, *server_configs,
+                     distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
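
Editor's note (sketch, not part of the commit): after this change, generate is an async generator that yields RequestOutput objects directly; both new frontends consume it the same way, roughly:

from cacheflow.sampling_params import SamplingParams

async def run(server, prompt: str) -> None:
    # `server` is assumed to be an AsyncLLMServer built via from_server_args.
    sampling_params = SamplingParams(n=1, max_tokens=16)
    async for request_output in server.generate(prompt, sampling_params):
        # Each yielded RequestOutput carries the partial completions so far.
        print([output.text for output in request_output.outputs])

# Run with asyncio.run(run(server, "San Francisco is a")) once a server exists.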
@@ -210,7 +210,8 @@ class LLMServer:
                     # Truncate the output text so that the stop string is
                     # not included in the output.
                     seq.output_text = seq.output_text[:-len(stop_str)]
-                    self.scheduler.free_seq(seq)
+                    self.scheduler.free_seq(seq,
+                                            SequenceStatus.FINISHED_STOPPED)
                     stopped = True
                     break
             if stopped:

@@ -218,12 +219,14 @@ class LLMServer:

             # Check if the sequence has reached max_tokens.
             if seq.get_output_len() == sampling_params.max_tokens:
-                self.scheduler.free_seq(seq)
+                self.scheduler.free_seq(
+                    seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                 continue
             # Check if the sequence has generated the EOS token.
             if not sampling_params.ignore_eos:
                 if seq.get_last_token_id() == self.tokenizer.eos_token_id:
-                    self.scheduler.free_seq(seq)
+                    self.scheduler.free_seq(seq,
+                                            SequenceStatus.FINISHED_STOPPED)
                     continue

     def _run_workers(
@@ -1,4 +1,5 @@
 import enum
+import uuid

 import psutil
 import torch

@@ -31,3 +32,7 @@ def get_gpu_memory(gpu: int = 0) -> int:
 def get_cpu_memory() -> int:
     """Returns the total CPU memory of the node in bytes."""
     return psutil.virtual_memory().total
+
+
+def random_uuid() -> str:
+    return str(uuid.uuid4().hex)
@@ -1,6 +1,5 @@
 import argparse
 import json
-import time

 import gradio as gr
 import requests

@@ -24,9 +23,9 @@ def http_bot(prompt):
 def build_demo():
     with gr.Blocks() as demo:
         gr.Markdown(
-            "# Cacheflow demo\n"
+            "# Cacheflow text completion demo\n"
         )
-        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")# .style(container=False)
+        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
         outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
         inputbox.submit(http_bot, [inputbox], [outputbox])
     return demo

@@ -35,9 +34,11 @@ def build_demo():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=10003)
-    parser.add_argument("--model-url", type=str, default="http://localhost:10002/generate")
+    parser.add_argument("--port", type=int, default=8002)
+    parser.add_argument("--model-url", type=str, default="http://localhost:8001/generate")
     args = parser.parse_args()

     demo = build_demo()
-    demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port)
+    demo.queue(concurrency_count=100).launch(server_name=args.host,
+                                             server_port=args.port,
+                                             share=True)
new file: examples/openai_client.py (22 lines)

import openai
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
model = "facebook/opt-125m"

# list models
models = openai.Model.list()
print(models)

# create a completion

stream = True
completion = openai.Completion.create(
    model=model, prompt="A robot may not injure a human being", echo=False, n=2,
    best_of=3, stream=stream, logprobs=3)

# print the completion
if stream:
    for c in completion:
        print(c)
else:
    print("completion:", completion)
new file: examples/simple_fastapi_client.py (48 lines)

import argparse
import requests
import json


def clear_line(n=1):
    LINE_UP = '\033[1A'
    LINE_CLEAR = '\x1b[2K'
    for i in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def http_request(prompt: str, api_url: str, n: int = 1):
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": n,
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": 16,
    }
    response = requests.post(api_url, headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--n", type=int, default=4)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    args = parser.parse_args()
    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n

    print(f"Prompt: {prompt}\n", flush=True)
    num_printed_lines = 0
    for h in http_request(prompt, api_url, n):
        clear_line(num_printed_lines)
        num_printed_lines = 0
        for i, line in enumerate(h):
            num_printed_lines += 1
            print(f"Beam candidate {i}: {line}", flush=True)
@@ -1,5 +1,4 @@
 import argparse
-import uuid

 from cacheflow import ServerArgs, LLMServer, SamplingParams

@@ -20,17 +19,19 @@ def main(args: argparse.Namespace):
          SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]

+    request_id = 0
+
     # Run the server.
     while True:
         # To test iteration-level scheduling, we add one request at each step.
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
-            request_id = str(uuid.uuid4().hex[:8])
-            server.add_request(request_id, prompt, sampling_params)
+            server.add_request(str(request_id), prompt, sampling_params)
+            request_id += 1

         request_outputs = server.step()
         for request_output in request_outputs:
-            if request_output.done:
+            if request_output.finished():
                 print(request_output)

         if not (server.has_unfinished_requests() or test_prompts):
deleted file (20 lines):

import requests
import json

def http_bot():
    prompt = "How are you? I'm fine."

    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
    }
    response = requests.post("http://localhost:10002", headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output

for h in http_bot():
    print(h, end="", flush=True)
deleted file (40 lines):

import argparse
import asyncio
import time
from typing import Union
import json

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn


app = FastAPI()


async def text_streamer(args):
    context = args["prompt"]
    words = context.split(" ")
    for word in words:
        await asyncio.sleep(1)
        print("word:", word)
        ret = {
            "text": word + " ",
            "error": 0,
        }
        yield (json.dumps(ret) + "\0").encode("utf-8")


@app.post("/")
async def read_root(request: Request):
    args = await request.json()
    return StreamingResponse(text_streamer(args))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
@@ -8,3 +8,4 @@ transformers >= 4.28.0  # Required for LLaMA.
 xformers >= 0.0.19
 fastapi
 uvicorn
+pydantic  # Required for OpenAI server.
deleted file (23 lines):

import requests
import json

def http_request():
    prompt = "Ion Stoica is a"

    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": 4,
        "use_beam_search": True,
        "temperature": 0.0,
    }
    response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output

for h in http_request():
    print(h, flush=True)