diff --git a/cacheflow/core/block_manager.py b/cacheflow/core/block_manager.py index 8f1295bc..07129b65 100644 --- a/cacheflow/core/block_manager.py +++ b/cacheflow/core/block_manager.py @@ -148,7 +148,7 @@ class BlockSpaceManager: # the sequences in the same group. blocks: Set[PhysicalTokenBlock] = set() for seq in seq_group.get_seqs(): - if seq.status == SequenceStatus.FINISHED: + if SequenceStatus.is_finished(seq.status): continue block_table = self.block_tables[seq.seq_id] for block in block_table: @@ -169,7 +169,7 @@ class BlockSpaceManager: # CPU block -> GPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(): - if seq.status == SequenceStatus.FINISHED: + if SequenceStatus.is_finished(seq.status): continue new_block_table: BlockTable = [] block_table = self.block_tables[seq.seq_id] @@ -200,7 +200,7 @@ class BlockSpaceManager: # GPU block -> CPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(): - if seq.status == SequenceStatus.FINISHED: + if SequenceStatus.is_finished(seq.status): continue new_block_table: BlockTable = [] block_table = self.block_tables[seq.seq_id] diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py index 1085e839..b4932c69 100644 --- a/cacheflow/core/scheduler.py +++ b/cacheflow/core/scheduler.py @@ -292,10 +292,12 @@ class Scheduler: # Append a new token to the sequence. output = seq_outputs[seq.seq_id] seq.append_token_id(output.output_token, output.logprobs) + # Return a shallow copy of the running queue to prevent the queue + # from being modified by the caller. return self.running.copy() - def free_seq(self, seq: Sequence) -> None: - seq.status = SequenceStatus.FINISHED + def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None: + seq.status = finish_status self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py new file mode 100644 index 00000000..4d32390b --- /dev/null +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -0,0 +1,300 @@ +# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py + +import argparse +from http import HTTPStatus +import json +import time +from typing import AsyncGenerator, Dict, List, Optional + +import fastapi +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn + +from cacheflow.outputs import RequestOutput +from cacheflow.server.arg_utils import ServerArgs +from cacheflow.server.async_llm_server import AsyncLLMServer +from cacheflow.server.tokenizer_utils import get_tokenizer +from cacheflow.logger import init_logger +from cacheflow.sampling_params import SamplingParams +from cacheflow.utils import random_uuid +from cacheflow.entrypoints.openai.protocol import ( + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + ErrorResponse, + LogProbs, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, +) + + +logger = init_logger(__name__) +served_model = None +app = fastapi.FastAPI() + + +def create_error_response(status_code: HTTPStatus, + message: str) -> JSONResponse: + return JSONResponse( + ErrorResponse(message=message, type="invalid_request_error").dict(), + 
status_code=status_code.value + ) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + return create_error_response(HTTPStatus.BAD_REQUEST, str(exc)) + + +async def check_model(request) -> Optional[JSONResponse]: + if request.model == served_model: + return + ret = create_error_response( + HTTPStatus.NOT_FOUND, + f"The model `{request.model}` does not exist.", + ) + return ret + + +@app.get("/v1/models") +async def show_available_models(): + """Show available models. Right now we only have one model.""" + model_cards = [ModelCard(id=served_model, root=served_model, + permission=[ModelPermission()])] + return ModelList(data=model_cards) + + +def create_logprobs(token_ids: List[int], + id_logprobs: List[Dict[int, float]], + initial_text_offset: int = 0) -> LogProbs: + """Create OpenAI-style logprobs.""" + logprobs = LogProbs() + last_token_len = 0 + for token_id, id_logprob in zip(token_ids, id_logprobs): + token = tokenizer.convert_ids_to_tokens(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(id_logprob[token_id]) + if len(logprobs.text_offset) == 0: + logprobs.text_offset.append(initial_text_offset) + else: + logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len) + last_token_len = len(token) + + logprobs.top_logprobs.append( + {tokenizer.convert_ids_to_tokens(i): p + for i, p in id_logprob.items()}) + return logprobs + + +@app.post("/v1/completions") +async def create_completion(request: CompletionRequest): + logger.info(f"Received completion request: {request}") + + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + if request.echo: + # We do not support echo since the cacheflow server does not + # currently support getting the logprobs of prompt tokens. + return create_error_response(HTTPStatus.BAD_REQUEST, + "echo is not currently supported") + + if request.suffix is not None: + # The language models we currently support do not support suffix. + return create_error_response(HTTPStatus.BAD_REQUEST, + "suffix is not currently supported") + + if request.logit_bias is not None: + # TODO: support logit_bias in cacheflow server. + return create_error_response(HTTPStatus.BAD_REQUEST, + "logit_bias is not currently supported") + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + prompt = request.prompt + created_time = int(time.time()) + try: + sampling_params = SamplingParams( + n=request.n, + best_of=request.best_of, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + stop=request.stop, + ignore_eos=request.ignore_eos, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + use_beam_search=request.use_beam_search, + ) + except ValueError as e: + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + + result_generator = server.generate(prompt, sampling_params, + request_id=request_id) + + # Similar to the OpenAI API, when n != best_of, we do not stream the + # results. In addition, we do not stream the results when use beam search. 
+ stream = (request.stream and + (request.best_of is None or request.n == request.best_of) and + not request.use_beam_search) + + def create_stream_response_json(index: int, + text: str, + logprobs: Optional[LogProbs] = None, + finish_reason: Optional[str] = None) -> str: + choice_data = CompletionResponseStreamChoice( + index=index, + text=text, + logprobs=logprobs, + finish_reason=finish_reason, + ) + response = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[choice_data], + ) + response_json = response.json(ensure_ascii=False) + + return response_json + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + async for res in result_generator: + res: RequestOutput + for output in res.outputs: + i = output.index + delta_text = output.text[len(previous_texts[i]):] + if request.logprobs is not None: + logprobs = create_logprobs( + output.token_ids[previous_num_tokens[i]:], + output.logprobs[previous_num_tokens[i]:], + len(previous_texts[i])) + else: + logprobs = None + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + response_json = create_stream_response_json( + index=i, + text=delta_text, + logprobs=logprobs, + ) + yield f"data: {response_json}\n\n" + if output.finish_reason is not None: + logprobs = LogProbs() if request.logprobs is not None else None + response_json = create_stream_response_json( + index=i, + text="", + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + # Streaming response + if stream: + return StreamingResponse(completion_stream_generator(), + media_type="text/event-stream") + + # Non-streaming response + final_res: RequestOutput = None + async for res in result_generator: + final_res = res + assert final_res is not None + choices = [] + for output in final_res.outputs: + if request.logprobs is not None: + logprobs = create_logprobs(output.token_ids, output.logprobs) + else: + logprobs = None + choice_data = CompletionResponseChoice( + index=output.index, + text=output.text, + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum(len(output.token_ids) + for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + if request.stream: + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. + response_json = response.json(ensure_ascii=False) + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + return StreamingResponse(fake_stream_generator(), + media_type="text/event-stream") + + return response + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="CacheFlow OpenAI-Compatible RESTful API server." 
+ ) + parser.add_argument("--host", type=str, default="localhost", help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" + ) + parser.add_argument( + "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" + ) + parser.add_argument( + "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" + ) + parser.add_argument( + "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" + ) + parser.add_argument("--served-model-name", type=str, default=None, + help="The model name used in the API. If not specified, " + "the model name will be the same as the " + "huggingface name.") + parser = ServerArgs.add_cli_args(parser) + args = parser.parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + logger.info(f"args: {args}") + + served_model = args.served_model_name or args.model + + server_args = ServerArgs.from_cli_args(args) + server = AsyncLLMServer.from_server_args(server_args) + + # A separate tokenizer to map token IDs to strings. + tokenizer = get_tokenizer(args.model) + + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/cacheflow/entrypoints/openai/protocol.py b/cacheflow/entrypoints/openai/protocol.py new file mode 100644 index 00000000..61ad60c2 --- /dev/null +++ b/cacheflow/entrypoints/openai/protocol.py @@ -0,0 +1,126 @@ +# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import time +from typing import Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, Field + +from cacheflow.utils import random_uuid + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: Optional[str] = None + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = False + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "cacheflow" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = Field(default_factory=list) + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[Dict[str, str]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + max_tokens: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + + +class CompletionRequest(BaseModel): + model: str + prompt: str + 
suffix: Optional[str] = None + max_tokens: Optional[int] = 16 + temperature: Optional[float] = 1.0 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + stream: Optional[bool] = False + logprobs: Optional[int] = None + echo: Optional[bool] = False + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + best_of: Optional[int] = None + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + # Additional parameters supported by cacheflow + top_k: Optional[int] = -1 + ignore_eos: Optional[bool] = False + use_beam_search: Optional[bool] = False + + +class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] diff --git a/cacheflow/entrypoints/simple_fastapi_frontend.py b/cacheflow/entrypoints/simple_fastapi_frontend.py new file mode 100644 index 00000000..e7e1357f --- /dev/null +++ b/cacheflow/entrypoints/simple_fastapi_frontend.py @@ -0,0 +1,51 @@ +import argparse +import json +from typing import AsyncGenerator + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import uvicorn + +from cacheflow.sampling_params import SamplingParams +from cacheflow.server.arg_utils import ServerArgs +from cacheflow.server.async_llm_server import AsyncLLMServer +from cacheflow.server.ray_utils import initialize_cluster + +TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds +app = FastAPI() + + +@app.post("/generate") +async def generate_stream(request: Request) -> StreamingResponse: + request_dict = await request.json() + prompt = request_dict.pop("prompt") + sampling_params = SamplingParams(**request_dict) + results_generator = server.generate(prompt, sampling_params) + + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + text_outputs = [ + prompt + output.text + for output in request_output.outputs + ] + ret = { + "text": text_outputs, + "error": 0, + } + yield (json.dumps(ret) + "\0").encode("utf-8") + + return StreamingResponse(stream_results()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8001) + parser = ServerArgs.add_cli_args(parser) + args = parser.parse_args() + + server_args = ServerArgs.from_cli_args(args) + server = 
AsyncLLMServer.from_server_args(server_args) + + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/cacheflow/model_executor/layers/sampler.py b/cacheflow/model_executor/layers/sampler.py index 425d5385..7782801e 100644 --- a/cacheflow/model_executor/layers/sampler.py +++ b/cacheflow/model_executor/layers/sampler.py @@ -1,5 +1,5 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional import numpy as np import torch @@ -258,9 +258,9 @@ def _apply_top_p_top_k( def _get_topk_logprobs( logprobs: torch.Tensor, - num_logprobs: int, + num_logprobs: Optional[int], ) -> Dict[int, float]: - if num_logprobs == 0: + if num_logprobs is None or num_logprobs == 0: return {} topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs) diff --git a/cacheflow/outputs.py b/cacheflow/outputs.py index 18b9a7cd..5d2b4f05 100644 --- a/cacheflow/outputs.py +++ b/cacheflow/outputs.py @@ -1,6 +1,6 @@ -from typing import Dict, List +from typing import Dict, List, Optional -from cacheflow.sequence import SequenceGroup +from cacheflow.sequence import SequenceGroup, SequenceStatus class CompletionOutput: @@ -12,19 +12,25 @@ class CompletionOutput: token_ids: List[int], cumulative_logprob: float, logprobs: List[Dict[int, float]], + finish_reason: Optional[str] = None, ) -> None: self.index = index self.text = text self.token_ids = token_ids self.cumulative_logprob = cumulative_logprob self.logprobs = logprobs + self.finish_reason = finish_reason + + def finished(self) -> bool: + return self.finish_reason is not None def __repr__(self) -> str: return (f"CompletionOutput(index={self.index}, " f"text={self.text!r}, " f"token_ids={self.token_ids}, " f"cumulative_logprob={self.cumulative_logprob}, " - f"logprobs={self.logprobs})") + f"logprobs={self.logprobs}," + f"finish_reason={self.finish_reason})") class RequestOutput: @@ -35,13 +41,11 @@ class RequestOutput: prompt: str, prompt_token_ids: List[int], outputs: List[CompletionOutput], - done: bool, ) -> None: self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids self.outputs = outputs - self.done = done @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": @@ -57,25 +61,28 @@ class RequestOutput: outputs: List[CompletionOutput] = [] for seq in top_n_seqs: logprobs = seq.output_logprobs - if seq_group.sampling_params.logprobs == 0: + if seq_group.sampling_params.logprobs is None: # NOTE: We need to take care of this case because the sequence # always has the logprobs of the sampled tokens even if the # logprobs are not requested. logprobs = {} + finshed_reason = SequenceStatus.get_finished_reason(seq.status) output = CompletionOutput(seqs.index(seq), seq.output_text, seq.get_output_token_ids(), - seq.get_cumulative_logprob(), logprobs) + seq.get_cumulative_logprob(), logprobs, + finshed_reason) outputs.append(output) # Every sequence in the sequence group should have the same prompt. 
prompt = top_n_seqs[0].prompt prompt_token_ids = top_n_seqs[0].data.prompt_token_ids - return cls(seq_group.request_id, prompt, prompt_token_ids, outputs, - seq_group.is_finished()) + return cls(seq_group.request_id, prompt, prompt_token_ids, outputs) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " - f"outputs={self.outputs}, " - f"done={self.done})") + f"outputs={self.outputs})") + + def finished(self) -> bool: + return all(output.finished() for output in self.outputs) diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py index 0ce772a9..031eb820 100644 --- a/cacheflow/sampling_params.py +++ b/cacheflow/sampling_params.py @@ -53,7 +53,7 @@ class SamplingParams: stop: Union[str, List[str]] = [], ignore_eos: bool = False, max_tokens: int = 16, - logprobs: int = 0, + logprobs: Optional[int] = None, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -98,7 +98,7 @@ class SamplingParams: if self.max_tokens < 1: raise ValueError( f"max_tokens must be at least 1, got {self.max_tokens}.") - if self.logprobs < 0: + if self.logprobs is not None and self.logprobs < 0: raise ValueError( f"logprobs must be non-negative, got {self.logprobs}.") diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index 02c5970e..db864609 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -1,6 +1,6 @@ import copy import enum -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from cacheflow.block import LogicalTokenBlock from cacheflow.sampling_params import SamplingParams @@ -10,8 +10,25 @@ class SequenceStatus(enum.Enum): WAITING = enum.auto() RUNNING = enum.auto() SWAPPED = enum.auto() - FINISHED = enum.auto() + FINISHED_STOPPED = enum.auto() + FINISHED_LENGTH_CAPPED = enum.auto() + @staticmethod + def is_finished(status: "SequenceStatus") -> bool: + return status in [ + SequenceStatus.FINISHED_STOPPED, + SequenceStatus.FINISHED_LENGTH_CAPPED, + ] + + @staticmethod + def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: + if status == SequenceStatus.FINISHED_STOPPED: + finish_reason = "stop" + elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: + finish_reason = "length" + else: + finish_reason = None + return finish_reason class SequenceData: @@ -20,7 +37,6 @@ class SequenceData: prompt_token_ids: List[int], ) -> None: self.prompt_token_ids = prompt_token_ids - self.output_token_ids: List[int] = [] self.cumulative_logprob = 0.0 @@ -166,7 +182,7 @@ class SequenceGroup: raise ValueError(f'Sequence {seq_id} not found.') def is_finished(self) -> bool: - return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs) + return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs) def __repr__(self) -> str: return (f"SequenceGroup(request_id={self.request_id}, " diff --git a/cacheflow/entrypoints/fastapi_server.py b/cacheflow/server/async_llm_server.py similarity index 66% rename from cacheflow/entrypoints/fastapi_server.py rename to cacheflow/server/async_llm_server.py index f69b82bc..8755b023 100644 --- a/cacheflow/entrypoints/fastapi_server.py +++ b/cacheflow/server/async_llm_server.py @@ -1,26 +1,20 @@ -import argparse import asyncio -import json import time -from typing import Any, Dict -import uuid +from typing import Dict, Optional -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse import ray -import uvicorn from cacheflow.outputs import 
RequestOutput from cacheflow.sampling_params import SamplingParams from cacheflow.server.arg_utils import ServerArgs from cacheflow.server.llm_server import LLMServer from cacheflow.server.ray_utils import initialize_cluster +from cacheflow.utils import random_uuid TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds -app = FastAPI() -class FastAPIServer: +class AsyncLLMServer: def __init__(self, server_use_ray: bool, *args, **kwargs) -> None: if server_use_ray: @@ -45,15 +39,15 @@ class FastAPIServer: self.request_outputs[request_id] = request_output self.request_events[request_id].set() - async def generate(self, request_dict: Dict[str, Any]): + async def generate(self, prompt: str, sampling_params: SamplingParams, + request_id: Optional[str] = None) -> RequestOutput: # Preprocess the request. arrival_time = time.time() - prompt = request_dict.pop("prompt") - sampling_params = SamplingParams(**request_dict) # Create an event to notify us that there is new output from the # cacheflow server. - request_id = str(uuid.uuid4().hex[:8]) + if request_id is None: + request_id = random_uuid() request_event = asyncio.Event() self.request_events[request_id] = request_event @@ -82,19 +76,10 @@ class FastAPIServer: # Decode and return new outputs. request_output = self.request_outputs[request_id] - prompt = request_output.prompt - text_outputs = [ - prompt + output.text - for output in request_output.outputs - ] - ret = { - "text": text_outputs, - "error": 0, - } - yield (json.dumps(ret) + "\0").encode("utf-8") + yield request_output # Once finished, release the resources of the sequence group. - if request_output.done: + if request_output.finished(): del self.request_outputs[request_id] del self.request_events[request_id] # Kick the server if the server is not running. This is to @@ -104,25 +89,15 @@ class FastAPIServer: await self.server_step() break - -@app.post("/generate") -async def generate_stream(request: Request): - request_dict = await request.json() - return StreamingResponse(server.generate(request_dict)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=10002) - parser = ServerArgs.add_cli_args(parser) - args = parser.parse_args() - - server_configs = ServerArgs.from_cli_args(args).create_server_configs() - parallel_config = server_configs[2] - distributed_init_method, stage_devices = initialize_cluster(parallel_config) - - server = FastAPIServer(args.use_ray, *server_configs, - distributed_init_method, stage_devices, - log_stats=not args.disable_log_stats) - uvicorn.run(app, host=args.host, port=args.port, log_level="info") + @classmethod + def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer": + # Create the server configs. + server_configs = server_args.create_server_configs() + parallel_config = server_configs[2] + # Initialize the cluster. + distributed_init_method, devices = initialize_cluster(parallel_config) + # Create the LLM server. + server = cls(server_args.use_ray, *server_configs, + distributed_init_method, devices, + log_stats=not server_args.disable_log_stats) + return server diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py index c35b7f93..b162b2a2 100644 --- a/cacheflow/server/llm_server.py +++ b/cacheflow/server/llm_server.py @@ -210,7 +210,8 @@ class LLMServer: # Truncate the output text so that the stop string is # not included in the output. 
seq.output_text = seq.output_text[:-len(stop_str)] - self.scheduler.free_seq(seq) + self.scheduler.free_seq(seq, + SequenceStatus.FINISHED_STOPPED) stopped = True break if stopped: @@ -218,12 +219,14 @@ class LLMServer: # Check if the sequence has reached max_tokens. if seq.get_output_len() == sampling_params.max_tokens: - self.scheduler.free_seq(seq) + self.scheduler.free_seq( + seq, SequenceStatus.FINISHED_LENGTH_CAPPED) continue # Check if the sequence has generated the EOS token. if not sampling_params.ignore_eos: if seq.get_last_token_id() == self.tokenizer.eos_token_id: - self.scheduler.free_seq(seq) + self.scheduler.free_seq(seq, + SequenceStatus.FINISHED_STOPPED) continue def _run_workers( @@ -238,10 +241,10 @@ class LLMServer: executor = getattr(worker, method) if self.parallel_config.use_ray: executor = executor.remote - + output = executor(*args, **kwargs) all_outputs.append(output) - + if self.parallel_config.use_ray: all_outputs = ray.get(all_outputs) diff --git a/cacheflow/utils.py b/cacheflow/utils.py index 4a40d16f..85fe1877 100644 --- a/cacheflow/utils.py +++ b/cacheflow/utils.py @@ -1,4 +1,5 @@ import enum +import uuid import psutil import torch @@ -31,3 +32,7 @@ def get_gpu_memory(gpu: int = 0) -> int: def get_cpu_memory() -> int: """Returns the total CPU memory of the node in bytes.""" return psutil.virtual_memory().total + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) diff --git a/gradio_webserver.py b/examples/gradio_webserver.py similarity index 80% rename from gradio_webserver.py rename to examples/gradio_webserver.py index d819ecab..e4a80c39 100644 --- a/gradio_webserver.py +++ b/examples/gradio_webserver.py @@ -1,6 +1,5 @@ import argparse import json -import time import gradio as gr import requests @@ -24,9 +23,9 @@ def http_bot(prompt): def build_demo(): with gr.Blocks() as demo: gr.Markdown( - "# Cacheflow demo\n" + "# Cacheflow text completion demo\n" ) - inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")# .style(container=False) + inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER") outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model") inputbox.submit(http_bot, [inputbox], [outputbox]) return demo @@ -35,9 +34,11 @@ def build_demo(): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=10003) - parser.add_argument("--model-url", type=str, default="http://localhost:10002/generate") + parser.add_argument("--port", type=int, default=8002) + parser.add_argument("--model-url", type=str, default="http://localhost:8001/generate") args = parser.parse_args() demo = build_demo() - demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port) \ No newline at end of file + demo.queue(concurrency_count=100).launch(server_name=args.host, + server_port=args.port, + share=True) diff --git a/examples/openai_client.py b/examples/openai_client.py new file mode 100644 index 00000000..9e711a8a --- /dev/null +++ b/examples/openai_client.py @@ -0,0 +1,22 @@ +import openai +openai.api_key = "EMPTY" +openai.api_base = "http://localhost:8000/v1" +model = "facebook/opt-125m" + +# list models +models = openai.Model.list() +print(models) + +# create a completion + +stream = True +completion = openai.Completion.create( + model=model, prompt="A robot may not injure a human being", echo=False, n=2, + best_of=3, stream=stream, logprobs=3) + +# print 
the completion +if stream: + for c in completion: + print(c) +else: + print("completion:", completion) diff --git a/examples/simple_fastapi_client.py b/examples/simple_fastapi_client.py new file mode 100644 index 00000000..d7d9d355 --- /dev/null +++ b/examples/simple_fastapi_client.py @@ -0,0 +1,48 @@ +import argparse +import requests +import json + +def clear_line(n=1): + LINE_UP = '\033[1A' + LINE_CLEAR = '\x1b[2K' + for i in range(n): + print(LINE_UP, end=LINE_CLEAR, flush=True) + + +def http_request(prompt: str, api_url: str, n: int = 1): + headers = {"User-Agent": "Test Client"} + pload = { + "prompt": prompt, + "n": n, + "use_beam_search": True, + "temperature": 0.0, + "max_tokens": 16, + } + response = requests.post(api_url, headers=headers, json=pload, stream=True) + + for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode("utf-8")) + output = data["text"] + yield output + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8001) + parser.add_argument("--n", type=int, default=4) + parser.add_argument("--prompt", type=str, default="San Francisco is a") + args = parser.parse_args() + prompt = args.prompt + api_url = f"http://{args.host}:{args.port}/generate" + n = args.n + + print(f"Prompt: {prompt}\n", flush=True) + num_printed_lines = 0 + for h in http_request(prompt, api_url, n): + clear_line(num_printed_lines) + num_printed_lines = 0 + for i, line in enumerate(h): + num_printed_lines += 1 + print(f"Beam candidate {i}: {line}", flush=True) diff --git a/examples/simple_server.py b/examples/simple_server.py index 781c05f7..8d5fcaf8 100644 --- a/examples/simple_server.py +++ b/examples/simple_server.py @@ -1,5 +1,4 @@ import argparse -import uuid from cacheflow import ServerArgs, LLMServer, SamplingParams @@ -20,17 +19,19 @@ def main(args: argparse.Namespace): SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)), ] + request_id = 0 + # Run the server. while True: # To test iteration-level scheduling, we add one request at each step. if test_prompts: prompt, sampling_params = test_prompts.pop(0) - request_id = str(uuid.uuid4().hex[:8]) - server.add_request(request_id, prompt, sampling_params) + server.add_request(str(request_id), prompt, sampling_params) + request_id += 1 request_outputs = server.step() for request_output in request_outputs: - if request_output.done: + if request_output.finished(): print(request_output) if not (server.has_unfinished_requests() or test_prompts): diff --git a/playground/http_client.py b/playground/http_client.py deleted file mode 100644 index ac13ac62..00000000 --- a/playground/http_client.py +++ /dev/null @@ -1,20 +0,0 @@ -import requests -import json - -def http_bot(): - prompt = "How are you? I'm fine." 
- - headers = {"User-Agent": "Test Client"} - pload = { - "prompt": prompt, - } - response = requests.post("http://localhost:10002", headers=headers, json=pload, stream=True) - - for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): - if chunk: - data = json.loads(chunk.decode("utf-8")) - output = data["text"] - yield output - -for h in http_bot(): - print(h, end="", flush=True) \ No newline at end of file diff --git a/playground/streaming_fastapi_worker.py b/playground/streaming_fastapi_worker.py deleted file mode 100644 index 8ab087d1..00000000 --- a/playground/streaming_fastapi_worker.py +++ /dev/null @@ -1,40 +0,0 @@ -import argparse -import asyncio -import time -from typing import Union -import json - -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse -import uvicorn - - -app = FastAPI() - - -async def text_streamer(args): - context = args["prompt"] - words = context.split(" ") - for word in words: - await asyncio.sleep(1) - print("word:", word) - ret = { - "text": word + " ", - "error": 0, - } - yield (json.dumps(ret) + "\0").encode("utf-8") - - -@app.post("/") -async def read_root(request: Request): - args = await request.json() - return StreamingResponse(text_streamer(args)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=10002) - args = parser.parse_args() - - uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/requirements.txt b/requirements.txt index bcb79da5..e84873ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ transformers >= 4.28.0 # Required for LLaMA. xformers >= 0.0.19 fastapi uvicorn +pydantic # Required for OpenAI server. diff --git a/test_cli_client.py b/test_cli_client.py deleted file mode 100644 index 217f8088..00000000 --- a/test_cli_client.py +++ /dev/null @@ -1,23 +0,0 @@ -import requests -import json - -def http_request(): - prompt = "Ion Stoica is a" - - headers = {"User-Agent": "Test Client"} - pload = { - "prompt": prompt, - "n": 4, - "use_beam_search": True, - "temperature": 0.0, - } - response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True) - - for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): - if chunk: - data = json.loads(chunk.decode("utf-8")) - output = data["text"] - yield output - -for h in http_request(): - print(h, flush=True)
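Usage note (illustrative, not part of the patch): a minimal client sketch for the new OpenAI-compatible endpoint added in cacheflow/entrypoints/openai/openai_frontend.py. It assumes that frontend is already running on its default localhost:8000 and serving facebook/opt-125m; the request fields and the "data: {...}" / "data: [DONE]" server-sent-event framing follow the protocol models defined above, while the helper name stream_completion is ours.

# Hypothetical client sketch: stream tokens from /v1/completions over plain
# HTTP, without the openai package. Assumes the OpenAI frontend is running
# on localhost:8000 and was started with facebook/opt-125m.
import json
import requests

def stream_completion(prompt: str,
                      api_url: str = "http://localhost:8000/v1/completions"):
    payload = {
        "model": "facebook/opt-125m",  # must match the served model name
        "prompt": prompt,
        "max_tokens": 32,
        "temperature": 0.0,
        "stream": True,
    }
    # The server streams server-sent events: "data: {json}" lines followed by
    # a final "data: [DONE]" sentinel.
    with requests.post(api_url, json=payload, stream=True) as response:
        for line in response.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            chunk = json.loads(data)
            yield chunk["choices"][0]["text"]

if __name__ == "__main__":
    for delta in stream_completion("A robot may not injure a human being"):
        print(delta, end="", flush=True)
    print()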
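A second sketch, also illustrative only: driving the refactored AsyncLLMServer directly from Python instead of going through a FastAPI frontend. from_server_args, generate, RequestOutput.finished, and SamplingParams are taken from the diff above; the --model flag is assumed to be registered by ServerArgs.add_cli_args, as the frontends' use of args.model suggests.

# Minimal sketch under the assumptions above: build an AsyncLLMServer from
# CLI-style args and consume its async generator of RequestOutput objects.
import argparse
import asyncio

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer

async def main():
    parser = argparse.ArgumentParser()
    parser = ServerArgs.add_cli_args(parser)
    # "--model" is assumed to be a ServerArgs flag (hypothetical value here).
    args = parser.parse_args(["--model", "facebook/opt-125m"])
    server = AsyncLLMServer.from_server_args(ServerArgs.from_cli_args(args))

    sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
    # generate() yields RequestOutput objects until every sequence in the
    # group reports a finish_reason, i.e. request_output.finished() is True.
    async for request_output in server.generate("San Francisco is a",
                                                sampling_params):
        if request_output.finished():
            for output in request_output.outputs:
                print(output.text)

if __name__ == "__main__":
    asyncio.run(main())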