Rename servers and change port numbers to reduce confusion (#149)

Zhuohan Li 2023-06-17 00:13:02 +08:00 committed by GitHub
parent 311490a720
commit eedb46bf03
10 changed files with 41 additions and 37 deletions

View File

@@ -52,7 +52,7 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--max-tokens", type=int, default=128)
     parser.add_argument("--n-threads", type=int, default=128)
     args = parser.parse_args()

View File

@@ -2,7 +2,7 @@ from cacheflow.entrypoints.llm import LLM
 from cacheflow.outputs import RequestOutput, CompletionOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
 from cacheflow.server.ray_utils import initialize_cluster
 __version__ = "0.1.0"
@@ -12,7 +12,7 @@ __all__ = [
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
-    "LLMServer",
+    "LLMEngine",
     "ServerArgs",
     "initialize_cluster",
 ]

View File

@@ -8,7 +8,7 @@ import uvicorn
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMServer
+from cacheflow.server.async_llm_server import AsyncLLMEngine
 from cacheflow.utils import random_uuid
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -18,7 +18,7 @@ app = FastAPI()
 @app.post("/generate")
 async def generate(request: Request) -> Response:
-    """ Stream the results of the generation request.
+    """Generate completion for the request.
     The request should be a JSON object with the following fields:
     - prompt: the prompt to use for the generation.
@@ -74,12 +74,12 @@ async def generate(request: Request) -> Response:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--port", type=int, default=8000)
     parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMServer.from_server_args(server_args)
+    server = AsyncLLMEngine.from_server_args(server_args)
     uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
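
The hunk above renames the FastAPI entrypoint's engine to AsyncLLMEngine and moves its default port to 8000. For orientation, here is a minimal sketch of calling the /generate route it exposes; only the prompt field is documented in this hunk, so the extra sampling field and the shape of the response JSON are assumptions rather than part of the commit.

# Minimal client sketch for the /generate route shown above, assuming the
# server runs on the new default port 8000. The response schema is not shown
# in this diff, so the result is printed as-is instead of indexing into it.
import requests

payload = {
    "prompt": "San Francisco is a",  # documented request field
    "max_tokens": 128,               # assumed field, mirroring the gradio client below
}
response = requests.post("http://localhost:8000/generate", json=payload)
response.raise_for_status()
print(response.json())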

View File

@@ -6,7 +6,7 @@ from tqdm import tqdm
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
 from cacheflow.utils import Counter
@@ -20,7 +20,7 @@ class LLM:
     mechanism and efficient memory management.
     NOTE: This class is intended to be used for offline inference. For online
-    serving, use the `AsyncLLMServer` class instead.
+    serving, use the `AsyncLLMEngine` class instead.
     NOTE: For the comprehensive list of arguments, see `ServerArgs`.
     Args:
@@ -52,7 +52,7 @@ class LLM:
            seed=seed,
            **kwargs,
        )
-        self.llm_server = LLMServer.from_server_args(server_args)
+        self.llm_server = LLMEngine.from_server_args(server_args)
        self.request_counter = Counter()
    def get_tokenizer(
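
The LLM docstring in the hunk above draws the line between offline inference (the LLM class) and online serving (AsyncLLMEngine). A minimal offline sketch follows; the LLM constructor arguments and the generate() signature are assumptions based on the imports in this file, since they are not shown in the diff.

# Offline inference sketch for the LLM wrapper described above. The
# constructor and generate() signatures are assumptions; only the imports
# appear in this diff.
from cacheflow import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # hypothetical model name for illustration
sampling_params = SamplingParams()    # default sampling; fields assumed
outputs = llm.generate(["San Francisco is a"], sampling_params)
for output in outputs:
    # Each item should be a RequestOutput per the imports above; its exact
    # fields are not shown here, so just print the object.
    print(output)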

View File

@@ -15,7 +15,7 @@ import uvicorn
 from cacheflow.outputs import RequestOutput
 from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMServer
+from cacheflow.server.async_llm_server import AsyncLLMEngine
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
@@ -319,7 +319,7 @@ if __name__ == "__main__":
    served_model = args.served_model_name or args.model
    server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMServer.from_server_args(server_args)
+    server = AsyncLLMEngine.from_server_args(server_args)
    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(args.model)

View File

@@ -6,7 +6,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
 from cacheflow.server.ray_utils import ray, initialize_cluster
 logger = init_logger(__name__)
@@ -14,26 +14,26 @@ logger = init_logger(__name__)
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
-class AsyncLLMServer:
-    """An asynchronous wrapper for LLMServer.
+class AsyncLLMEngine:
+    """An asynchronous wrapper for LLMEngine.
-    This class is used to wrap the LLMServer class to make it asynchronous. It
+    This class is used to wrap the LLMEngine class to make it asynchronous. It
     uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMServer is kicked by the generate method when there
+    requests. The LLMEngine is kicked by the generate method when there
     are requests in the waiting queue. The generate method yields the outputs
-    from the LLMServer to the caller.
+    from the LLMEngine to the caller.
-    NOTE: For the comprehensive list of arguments, see `LLMServer`.
+    NOTE: For the comprehensive list of arguments, see `LLMEngine`.
     Args:
         worker_use_ray: Whether to use Ray for model workers. Required for
             distributed execution. Should be the same as
             `parallel_config.worker_use_ray`.
-        server_use_ray: Whether to make LLMServer a Ray actor. If so, the
+        server_use_ray: Whether to make LLMEngine a Ray actor. If so, the
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
-        *args, *kwargs: Arguments for LLMServer.
+        *args, *kwargs: Arguments for LLMEngine.
     """
     def __init__(self, worker_use_ray: bool, server_use_ray: bool,
                  log_requests: bool = True, *args, **kwargs) -> None:
@@ -41,11 +41,11 @@ class AsyncLLMServer:
         self.server_use_ray = server_use_ray
         self.log_requests = log_requests
         if not self.server_use_ray:
-            server_class = LLMServer
+            server_class = LLMEngine
         elif self.worker_use_ray:
-            server_class = ray.remote(num_cpus=0)(LLMServer).remote
+            server_class = ray.remote(num_cpus=0)(LLMEngine).remote
         else:
-            server_class = ray.remote(num_gpus=1)(LLMServer).remote
+            server_class = ray.remote(num_gpus=1)(LLMEngine).remote
         self.server = server_class(*args, **kwargs)
         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
@@ -85,8 +85,8 @@ class AsyncLLMServer:
         """Generate outputs for a request.
         Generate outputs for a request. This method is a coroutine. It adds the
-        request into the waiting queue of the LLMServer and streams the outputs
-        from the LLMServer to the caller.
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
         Args:
             prompt: The prompt string. Can be None if prompt_token_ids is
@@ -97,7 +97,7 @@ class AsyncLLMServer:
                use the tokenizer to convert the prompts to token IDs.
         Yields:
-            The output `RequestOutput` objects from the LLMServer for the
+            The output `RequestOutput` objects from the LLMEngine for the
             request.
         """
         # Preprocess the request.
@@ -200,7 +200,7 @@ class AsyncLLMServer:
         self.kicking_request_id = None
     @classmethod
-    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
+    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMEngine":
         """Creates an async LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()
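
Per the AsyncLLMEngine docstring above, generate() is a coroutine that yields RequestOutput objects as the wrapped LLMEngine produces them. A rough consumer sketch follows, assuming generate() accepts a prompt, a SamplingParams object, and a string request id (the API server in this commit imports random_uuid, which suggests string ids); the exact signature is not in the diff.

# Sketch of consuming AsyncLLMEngine.generate(). The argument order and the
# string request id are assumptions; only from_server_args() and the yielded
# RequestOutput type are visible in this diff.
from cacheflow import SamplingParams
from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.utils import random_uuid

async def run_one_request(engine: AsyncLLMEngine, prompt: str) -> None:
    # The engine itself is built as in the API servers above:
    #   engine = AsyncLLMEngine.from_server_args(server_args)
    sampling_params = SamplingParams()  # default sampling; fields assumed
    request_id = random_uuid()
    async for request_output in engine.generate(prompt, sampling_params,
                                                request_id):
        print(request_output)  # streamed partial outputs for this request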

View File

@@ -18,7 +18,7 @@ from cacheflow.worker.worker import Worker
 logger = init_logger(__name__)
-class LLMServer:
+class LLMEngine:
     """An LLM server that receives requests and generates texts.
     This is the main class for the CacheFlow LLM server. It receives requests
@@ -29,7 +29,7 @@ class LLMServer:
     serving throughput.
     The `LLM` class wraps this class for offline batched inference and the
-    `AsyncLLMServer` class wraps this class for online serving.
+    `AsyncLLMEngine` class wraps this class for online serving.
     NOTE: The config arguments are derived from the `ServerArgs` class. For the
     comprehensive list of arguments, see `ServerArgs`.
@@ -135,7 +135,7 @@ class LLMServer:
         self._run_workers("init_cache_engine", cache_config=self.cache_config)
     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+    def from_server_args(cls, server_args: ServerArgs) -> "LLMEngine":
         """Creates an LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()
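
The LLMEngine docstring above positions it as the core that both LLM and AsyncLLMEngine wrap. Every entrypoint in this commit builds it the same way: parse CLI flags into ServerArgs (or AsyncServerArgs), then call from_server_args(). Below is a condensed sketch of the synchronous path, mirroring the renamed demo at the end of this commit.

# Construction pattern used by the entrypoints in this commit:
# CLI flags -> ServerArgs -> LLMEngine.from_server_args().
import argparse

from cacheflow import ServerArgs, LLMEngine

parser = argparse.ArgumentParser(description="Build an LLMEngine from CLI flags.")
parser = ServerArgs.add_cli_args(parser)  # engine flags, as in the demo script
args = parser.parse_args()

engine = LLMEngine.from_server_args(ServerArgs.from_cli_args(args))
# Request submission and the step loop are not part of this diff, so they are
# omitted here; the renamed demo at the end of this commit drives the engine.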

View File

@@ -1,3 +1,5 @@
+"""Example Python client for cacheflow.entrypoints.api_server"""
+
 import argparse
 import json
 from typing import Iterable, List
@@ -45,7 +47,7 @@ def get_response(response: requests.Response) -> List[str]:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--n", type=int, default=4)
     parser.add_argument("--prompt", type=str, default="San Francisco is a")
     parser.add_argument("--stream", action="store_true")

View File

@@ -9,6 +9,7 @@ def http_bot(prompt):
     headers = {"User-Agent": "Cacheflow Client"}
     pload = {
         "prompt": prompt,
+        "stream": True,
         "max_tokens": 128,
     }
     response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
@@ -34,8 +35,8 @@ def build_demo():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8002)
-    parser.add_argument("--model-url", type=str, default="http://localhost:8001/generate")
+    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate")
     args = parser.parse_args()
     demo = build_demo()
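
The gradio client above now sends "stream": True and posts with stream=True, so it reads /generate output incrementally. Below is a rough sketch of consuming such a streamed response; the b"\0" chunk delimiter and the per-chunk JSON schema are assumptions about the server's streaming format, which is not part of this diff.

# Streaming-consumption sketch matching the payload used by the gradio client
# above. The b"\0" delimiter and the JSON-per-chunk format are assumptions.
import json

import requests

payload = {"prompt": "San Francisco is a", "stream": True, "max_tokens": 128}
response = requests.post("http://localhost:8000/generate", json=payload,
                         stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                 delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data)  # inspect each partial result rather than assuming its keys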

View File

@@ -1,12 +1,12 @@
 import argparse
-from cacheflow import ServerArgs, LLMServer, SamplingParams
+from cacheflow import ServerArgs, LLMEngine, SamplingParams
 def main(args: argparse.Namespace):
     # Parse the CLI argument and initialize the server.
     server_args = ServerArgs.from_cli_args(args)
-    server = LLMServer.from_server_args(server_args)
+    server = LLMEngine.from_server_args(server_args)
     # Test the following prompts.
     test_prompts = [
@@ -38,7 +38,8 @@ def main(args: argparse.Namespace):
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
+    parser = argparse.ArgumentParser(
+        description='Demo on using the LLMEngine class synchronously')
     parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)