diff --git a/benchmarks/benchmark_async_llm_server.py b/benchmarks/benchmark_async_llm_server.py
index 3fbd8322..161c6c59 100644
--- a/benchmarks/benchmark_async_llm_server.py
+++ b/benchmarks/benchmark_async_llm_server.py
@@ -52,7 +52,7 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--max-tokens", type=int, default=128)
     parser.add_argument("--n-threads", type=int, default=128)
     args = parser.parse_args()
diff --git a/cacheflow/__init__.py b/cacheflow/__init__.py
index cdc0c183..6e222c9c 100644
--- a/cacheflow/__init__.py
+++ b/cacheflow/__init__.py
@@ -2,7 +2,7 @@ from cacheflow.entrypoints.llm import LLM
 from cacheflow.outputs import RequestOutput, CompletionOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
 from cacheflow.server.ray_utils import initialize_cluster
 
 __version__ = "0.1.0"
@@ -12,7 +12,7 @@ __all__ = [
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
-    "LLMServer",
+    "LLMEngine",
     "ServerArgs",
     "initialize_cluster",
 ]
diff --git a/cacheflow/entrypoints/simple_fastapi_frontend.py b/cacheflow/entrypoints/api_server.py
similarity index 92%
rename from cacheflow/entrypoints/simple_fastapi_frontend.py
rename to cacheflow/entrypoints/api_server.py
index 07933003..baff56b9 100644
--- a/cacheflow/entrypoints/simple_fastapi_frontend.py
+++ b/cacheflow/entrypoints/api_server.py
@@ -8,7 +8,7 @@ import uvicorn
 
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMServer
+from cacheflow.server.async_llm_server import AsyncLLMEngine
 from cacheflow.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -18,7 +18,7 @@ app = FastAPI()
 
 @app.post("/generate")
 async def generate(request: Request) -> Response:
-    """ Stream the results of the generation request.
+    """Generate completion for the request.
 
     The request should be a JSON object with the following fields:
     - prompt: the prompt to use for the generation.
@@ -74,12 +74,12 @@ async def generate(request: Request) -> Response:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--port", type=int, default=8000)
     parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()
 
     server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMServer.from_server_args(server_args)
+    server = AsyncLLMEngine.from_server_args(server_args)
 
     uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
diff --git a/cacheflow/entrypoints/llm.py b/cacheflow/entrypoints/llm.py
index f61e16b2..836cd700 100644
--- a/cacheflow/entrypoints/llm.py
+++ b/cacheflow/entrypoints/llm.py
@@ -6,7 +6,7 @@ from tqdm import tqdm
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
 from cacheflow.utils import Counter
 
 
@@ -20,7 +20,7 @@ class LLM:
     mechanism and efficient memory management.
 
     NOTE: This class is intended to be used for offline inference. For online
-    serving, use the `AsyncLLMServer` class instead.
+    serving, use the `AsyncLLMEngine` class instead.
 
     NOTE: For the comprehensive list of arguments, see `ServerArgs`.
 
     Args:
@@ -52,7 +52,7 @@ class LLM:
             seed=seed,
             **kwargs,
         )
-        self.llm_server = LLMServer.from_server_args(server_args)
+        self.llm_server = LLMEngine.from_server_args(server_args)
         self.request_counter = Counter()
 
     def get_tokenizer(
diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/api_server.py
similarity index 99%
rename from cacheflow/entrypoints/openai/openai_frontend.py
rename to cacheflow/entrypoints/openai/api_server.py
index 125537fb..62fa4b8d 100644
--- a/cacheflow/entrypoints/openai/openai_frontend.py
+++ b/cacheflow/entrypoints/openai/api_server.py
@@ -15,7 +15,7 @@ import uvicorn
 
 from cacheflow.outputs import RequestOutput
 from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMServer
+from cacheflow.server.async_llm_server import AsyncLLMEngine
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
@@ -319,7 +319,7 @@ if __name__ == "__main__":
     served_model = args.served_model_name or args.model
 
     server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMServer.from_server_args(server_args)
+    server = AsyncLLMEngine.from_server_args(server_args)
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)
diff --git a/cacheflow/server/async_llm_server.py b/cacheflow/server/async_llm_server.py
index 92a42f95..e8e8e7b9 100644
--- a/cacheflow/server/async_llm_server.py
+++ b/cacheflow/server/async_llm_server.py
@@ -6,7 +6,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
 from cacheflow.server.ray_utils import ray, initialize_cluster
 
 logger = init_logger(__name__)
@@ -14,26 +14,26 @@ logger = init_logger(__name__)
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 
 
-class AsyncLLMServer:
-    """An asynchronous wrapper for LLMServer.
+class AsyncLLMEngine:
+    """An asynchronous wrapper for LLMEngine.
 
-    This class is used to wrap the LLMServer class to make it asynchronous. It
+    This class is used to wrap the LLMEngine class to make it asynchronous. It
     uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMServer is kicked by the generate method when there
+    requests. The LLMEngine is kicked by the generate method when there
     are requests in the waiting queue. The generate method yields the outputs
-    from the LLMServer to the caller.
+    from the LLMEngine to the caller.
 
-    NOTE: For the comprehensive list of arguments, see `LLMServer`.
+    NOTE: For the comprehensive list of arguments, see `LLMEngine`.
 
     Args:
         worker_use_ray: Whether to use Ray for model workers. Required for
             distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
-        server_use_ray: Whether to make LLMServer a Ray actor. If so, the
+        server_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process as the
            model workers.
         log_requests: Whether to log the requests.
-        *args, *kwargs: Arguments for LLMServer.
+        *args, *kwargs: Arguments for LLMEngine.
     """
 
     def __init__(self, worker_use_ray: bool, server_use_ray: bool,
                  log_requests: bool = True, *args, **kwargs) -> None:
@@ -41,11 +41,11 @@ class AsyncLLMServer:
         self.server_use_ray = server_use_ray
         self.log_requests = log_requests
         if not self.server_use_ray:
-            server_class = LLMServer
+            server_class = LLMEngine
         elif self.worker_use_ray:
-            server_class = ray.remote(num_cpus=0)(LLMServer).remote
+            server_class = ray.remote(num_cpus=0)(LLMEngine).remote
         else:
-            server_class = ray.remote(num_gpus=1)(LLMServer).remote
+            server_class = ray.remote(num_gpus=1)(LLMEngine).remote
         self.server = server_class(*args, **kwargs)
         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
@@ -85,8 +85,8 @@ class AsyncLLMServer:
         """Generate outputs for a request.
 
         Generate outputs for a request. This method is a coroutine. It adds the
-        request into the waiting queue of the LLMServer and streams the outputs
-        from the LLMServer to the caller.
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
 
         Args:
             prompt: The prompt string. Can be None if prompt_token_ids is
@@ -97,7 +97,7 @@ class AsyncLLMServer:
                 use the tokenizer to convert the prompts to token IDs.
 
         Yields:
-            The output `RequestOutput` objects from the LLMServer for the
+            The output `RequestOutput` objects from the LLMEngine for the
             request.
         """
         # Preprocess the request.
@@ -200,7 +200,7 @@ class AsyncLLMServer:
         self.kicking_request_id = None
 
     @classmethod
-    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
+    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMEngine":
         """Creates an async LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()
diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py
index 6a9107f8..c3a9d943 100644
--- a/cacheflow/server/llm_server.py
+++ b/cacheflow/server/llm_server.py
@@ -18,7 +18,7 @@ from cacheflow.worker.worker import Worker
 logger = init_logger(__name__)
 
 
-class LLMServer:
+class LLMEngine:
     """An LLM server that receives requests and generates texts.
 
     This is the main class for the CacheFlow LLM server. It receives requests
@@ -29,7 +29,7 @@ class LLMServer:
     serving throughput.
 
     The `LLM` class wraps this class for offline batched inference and the
-    `AsyncLLMServer` class wraps this class for online serving.
+    `AsyncLLMEngine` class wraps this class for online serving.
 
     NOTE: The config arguments are derived from the `ServerArgs` class. For the
     comprehensive list of arguments, see `ServerArgs`.
@@ -135,7 +135,7 @@ class LLMServer:
         self._run_workers("init_cache_engine", cache_config=self.cache_config)
 
     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+    def from_server_args(cls, server_args: ServerArgs) -> "LLMEngine":
         """Creates an LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()
diff --git a/examples/simple_fastapi_client.py b/examples/api_client.py
similarity index 94%
rename from examples/simple_fastapi_client.py
rename to examples/api_client.py
index 45258cf6..53571444 100644
--- a/examples/simple_fastapi_client.py
+++ b/examples/api_client.py
@@ -1,3 +1,5 @@
+"""Example Python client for cacheflow.entrypoints.api_server"""
+
 import argparse
 import json
 from typing import Iterable, List
@@ -45,7 +47,7 @@ def get_response(response: requests.Response) -> List[str]:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--n", type=int, default=4)
     parser.add_argument("--prompt", type=str, default="San Francisco is a")
     parser.add_argument("--stream", action="store_true")
diff --git a/examples/gradio_webserver.py b/examples/gradio_webserver.py
index e4a80c39..107e8714 100644
--- a/examples/gradio_webserver.py
+++ b/examples/gradio_webserver.py
@@ -9,6 +9,7 @@ def http_bot(prompt):
     headers = {"User-Agent": "Cacheflow Client"}
     pload = {
         "prompt": prompt,
+        "stream": True,
         "max_tokens": 128,
     }
     response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
@@ -34,8 +35,8 @@ def build_demo():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8002)
-    parser.add_argument("--model-url", type=str, default="http://localhost:8001/generate")
+    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate")
     args = parser.parse_args()
 
     demo = build_demo()
diff --git a/examples/simple_server.py b/examples/llmserver_example.py
similarity index 86%
rename from examples/simple_server.py
rename to examples/llmserver_example.py
index d43e1bc8..d7f3777d 100644
--- a/examples/simple_server.py
+++ b/examples/llmserver_example.py
@@ -1,12 +1,12 @@
 import argparse
 
-from cacheflow import ServerArgs, LLMServer, SamplingParams
+from cacheflow import ServerArgs, LLMEngine, SamplingParams
 
 
 def main(args: argparse.Namespace):
     # Parse the CLI argument and initialize the server.
     server_args = ServerArgs.from_cli_args(args)
-    server = LLMServer.from_server_args(server_args)
+    server = LLMEngine.from_server_args(server_args)
 
     # Test the following prompts.
     test_prompts = [
@@ -38,7 +38,8 @@ def main(args: argparse.Namespace):
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
+    parser = argparse.ArgumentParser(
+        description='Demo on using the LLMEngine class synchronously')
     parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)
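
Below is a small client sketch against the renamed cacheflow.entrypoints.api_server endpoint. It uses only the request fields that appear in this patch (prompt, n, max_tokens, stream) and the new default port 8000; the launch command in the docstring and the exact shape of the JSON response are assumptions, not something this patch specifies.

"""Hypothetical client sketch for the renamed cacheflow.entrypoints.api_server.

Assumes the server is already running locally on the new default port 8000,
e.g. started with something like:
    python -m cacheflow.entrypoints.api_server --model <model-name>
(the launch command and model name are illustrative, not taken from the diff).
"""
import json

import requests

# New default address after this change (the old default port was 8001).
api_url = "http://localhost:8000/generate"

payload = {
    "prompt": "San Francisco is a",  # same demo prompt as examples/api_client.py
    "n": 4,                          # number of completions, as in api_client.py
    "max_tokens": 128,
    "stream": False,                 # examples/gradio_webserver.py now sets this to True
}

response = requests.post(api_url, json=payload)
# Assumes a JSON body for the non-streaming case, as examples/api_client.py does.
print(json.dumps(response.json(), indent=2))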
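
And a minimal offline sketch using the LLM wrapper, which per this patch now constructs an engine via LLMEngine.from_server_args internally. The generate call and the SamplingParams fields below are assumptions about the surrounding cacheflow API that this patch does not touch; treat it as an illustration of the renamed classes rather than a definitive recipe.

"""Hypothetical offline-inference sketch around the LLM -> LLMEngine wrapper.

Only the constructor arguments shown in the diff (model, seed) are confirmed;
llm.generate(...) and the SamplingParams fields are assumptions and may differ.
"""
from cacheflow import LLM, SamplingParams

# LLM builds ServerArgs internally and now calls LLMEngine.from_server_args
# (previously LLMServer.from_server_args), as shown in cacheflow/entrypoints/llm.py.
llm = LLM(model="facebook/opt-125m", seed=0)  # model name is illustrative

sampling_params = SamplingParams(n=1, temperature=0.8, max_tokens=128)

outputs = llm.generate(
    ["San Francisco is a", "The capital of France is"],
    sampling_params,
)
for output in outputs:
    # Each RequestOutput (exported from cacheflow/__init__.py) holds the prompt
    # and a list of CompletionOutput objects with the generated text.
    print(output.prompt, output.outputs[0].text)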