Add docstrings for LLMServer and related classes and examples (#142)

Zhuohan Li authored on 2023-06-07 18:25:20 +08:00; committed by GitHub
parent e38074b1e6
commit 4298374265
10 changed files with 212 additions and 18 deletions

View File

@ -12,6 +12,20 @@ _GiB = 1 << 30
class ModelConfig:
"""Configuration for the model.
Args:
model: Name or path of the huggingface model to use.
download_dir: Directory to download and load the weights; defaults to
the default cache directory of huggingface.
use_np_weights: Save a numpy copy of model weights for faster loading.
This can increase the disk usage by up to 2x.
use_dummy_weights: Use dummy values for model weights (for profiling).
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
seed: Random seed for reproducibility.
"""
def __init__(
self,
@ -68,7 +82,14 @@ class ModelConfig:
class CacheConfig:
"""Configuration for the KV cache.
Args:
block_size: Size of a cache block in number of tokens.
gpu_memory_utilization: Fraction of GPU memory to use for the
CacheFlow execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
"""
def __init__(
self,
block_size: int,
@ -111,7 +132,15 @@ class CacheConfig:
class ParallelConfig:
"""Configuration for the distributed execution.
Args:
pipeline_parallel_size: Number of pipeline parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
worker_use_ray: Whether to use Ray for model workers. Will be set to
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
"""
def __init__(
self,
pipeline_parallel_size: int,
@ -134,7 +163,14 @@ class ParallelConfig:
class SchedulerConfig:
"""Scheduler configuration.
Args:
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
iteration.
"""
def __init__(
self,
max_num_batched_tokens: int,
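
A rough sketch of how these four configs fit together, assuming the constructor parameters match the docstrings above (the full signatures are truncated in this diff, and all values are illustrative):

from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)

# Sketch only: parameter lists are inferred from the docstrings above.
model_config = ModelConfig(
    model="facebook/opt-125m",    # HF model name or local path
    download_dir=None,            # use the default huggingface cache
    use_np_weights=False,
    use_dummy_weights=False,
    dtype="auto",                 # FP16 for FP32/FP16 models, BF16 for BF16 models
    seed=0,
)
cache_config = CacheConfig(
    block_size=16,                # tokens per KV-cache block
    gpu_memory_utilization=0.90,  # fraction of GPU memory for CacheFlow
    swap_space=4,                 # CPU swap space per GPU, in GiB
)
parallel_config = ParallelConfig(
    pipeline_parallel_size=1,
    tensor_parallel_size=1,
    worker_use_ray=False,
)
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=2560,
    max_num_seqs=256,
)

In practice these objects are usually built by `ServerArgs.create_server_configs()` (see `cacheflow/server/arg_utils.py` below) rather than constructed by hand.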

View File

@ -96,6 +96,18 @@ def create_logprobs(token_ids: List[int],
@app.post("/v1/completions")
async def create_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
- echo (since the cacheflow server does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently serve do not support
suffix completion)
- logit_bias (to be supported by the cacheflow server)
"""
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}")

View File

@ -18,6 +18,12 @@ app = FastAPI()
@app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse:
""" Stream the results of the generation request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
sampling_params = SamplingParams(**request_dict)
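
A minimal client call for this endpoint might look like the sketch below; the host, port, and chunk framing of the streamed response are assumptions, and the extra fields are forwarded to `SamplingParams`:

import requests

# Sketch: "prompt" plus any SamplingParams fields in one JSON body.
payload = {
    "prompt": "Hello, my name is",
    "n": 2,              # SamplingParams field
    "temperature": 0.8,  # SamplingParams field
    "max_tokens": 32,    # SamplingParams field (name assumed)
}
response = requests.post(
    "http://localhost:8000/generate", json=payload, stream=True)
for chunk in response.iter_lines():
    if chunk:
        # Each chunk is assumed to be one JSON-encoded partial result.
        print(chunk.decode("utf-8"))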

View File

@ -9,6 +9,7 @@ from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
@dataclass
class ServerArgs:
"""Arguments for CacheFlow servers."""
model: str
download_dir: Optional[str] = None
use_np_weights: bool = False
@ -117,6 +118,7 @@ class ServerArgs:
@dataclass
class AsyncServerArgs(ServerArgs):
"""Arguments for asynchronous CacheFlow servers."""
server_use_ray: bool = False
@staticmethod

View File

@ -1,6 +1,6 @@
import asyncio
import time
from typing import Dict, Optional
from typing import Dict, List, Optional
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
@ -15,7 +15,25 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
class AsyncLLMServer:
"""An asynchronous wrapper for LLMServer.
This class wraps the LLMServer class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
requests. The LLMServer is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
from the LLMServer to the caller.
NOTE: For the comprehensive list of arguments, see `LLMServer`.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
server_use_ray: Whether to make LLMServer a Ray actor. If so, the
async frontend will be executed in a process separate from the
model workers.
*args, **kwargs: Arguments for LLMServer.
"""
def __init__(self, worker_use_ray: bool, server_use_ray: bool,
*args, **kwargs) -> None:
self.worker_use_ray = worker_use_ray
@ -35,6 +53,7 @@ class AsyncLLMServer:
self.kicking_request_id: Optional[str] = None
async def server_step(self, kicking_request_id: Optional[str] = None):
"""Kick the server to process the waiting requests."""
self.is_server_running = True
self.kicking_request_id = kicking_request_id
if self.server_use_ray:
@ -54,8 +73,31 @@ class AsyncLLMServer:
self.request_outputs[request_id] = request_output
self.request_events[request_id].set()
async def generate(self, prompt: str, sampling_params: SamplingParams,
request_id: str) -> RequestOutput:
async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None
) -> RequestOutput:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMServer and streams the outputs
from the LLMServer to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompt to token IDs.
Yields:
The output `RequestOutput` objects from the LLMServer for the
request.
"""
# Preprocess the request.
arrival_time = time.time()
@ -66,20 +108,29 @@ class AsyncLLMServer:
logger.info(f"Received request {request_id}: "
f"prompt: {prompt!r}, "
f"sampling params: {sampling_params}.")
f"sampling params: {sampling_params}, "
f"prompt token ids: {prompt_token_ids}.")
# Add the request into the cacheflow server's waiting queue.
if self.server_use_ray:
await self.server.add_request.remote(
request_id, prompt, sampling_params, arrival_time=arrival_time)
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
else:
self.server.add_request(
request_id, prompt, sampling_params, arrival_time=arrival_time)
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
# The cacheflow server does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
# the server to process the requests.
while True:
if request_id not in self.request_events:
# The request has been aborted.
return
# Kick the server if the server is not running.
if not self.is_server_running:
await self.server_step(request_id)
@ -113,6 +164,14 @@ class AsyncLLMServer:
break
async def abort(self, request_id: str) -> None:
"""Abort a request.
Abort a submitted request. If the request is finished or not found,
this method will be a no-op.
Args:
request_id: The unique id of the request.
"""
if request_id not in self.request_events:
# The request has already finished or been aborted.
return
@ -137,6 +196,7 @@ class AsyncLLMServer:
@classmethod
def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
"""Creates an async LLM server from the server arguments."""
# Create the server configs.
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]
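
Putting the pieces together, online usage might look roughly like this. It is a sketch only: the module path of `AsyncLLMServer` is assumed, and `generate` is consumed with `async for` because the docstring above says it yields `RequestOutput` objects.

import asyncio
from uuid import uuid4

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer  # path assumed

async def main() -> None:
    # Build the async server from the argument dataclass.
    server_args = AsyncServerArgs(model="facebook/opt-125m")
    server = AsyncLLMServer.from_server_args(server_args)

    sampling_params = SamplingParams(temperature=0.8, n=1)
    request_id = str(uuid4())

    # `generate` yields RequestOutput objects as the request makes
    # progress; here we simply print each partial result.
    async for request_output in server.generate(
            "The future of AI is", sampling_params, request_id):
        print(request_output)

asyncio.run(main())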

View File

@ -8,7 +8,7 @@ from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import ray, initialize_cluster
from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
from cacheflow.server.tokenizer_utils import (get_tokenizer,
detokenize_incrementally)
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@ -19,6 +19,33 @@ logger = init_logger(__name__)
class LLMServer:
"""An LLM server that receives requests and generates texts.
This is the main class for the CacheFlow LLM server. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMServer` class wraps this class for online serving.
NOTE: The config arguments are derived from the `ServerArgs` class. For the
comprehensive list of arguments, see `ServerArgs`.
Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
distributed_init_method: The initialization method for distributed
execution. See `torch.distributed.init_process_group` for details.
stage_devices: The list of devices for each stage. Each stage is a list
of (rank, node_resource, device) tuples.
log_stats: Whether to log statistics.
"""
def __init__(
self,
@ -27,7 +54,7 @@ class LLMServer:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
distributed_init_method: str,
stage_devices: List[List[Any]],
stage_devices: List[List[DeviceID]],
log_stats: bool,
) -> None:
logger.info(
@ -83,6 +110,7 @@ class LLMServer:
self.cache_config.verify_with_parallel_config(self.parallel_config)
def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache."""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers(
"profile_num_available_blocks",
@ -108,6 +136,7 @@ class LLMServer:
@classmethod
def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
"""Creates an LLM server from the server arguments."""
# Create the server configs.
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]
@ -126,6 +155,22 @@ class LLMServer:
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
) -> None:
"""Add a request to the server's request pool.
The request is added to the request pool and will be processed by the
scheduler as `server.step()` is called. The exact scheduling policy is
determined by the scheduler.
Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompt to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current time.
"""
if arrival_time is None:
arrival_time = time.time()
if prompt_token_ids is None:
@ -148,15 +193,30 @@ class LLMServer:
self.scheduler.add_seq_group(seq_group)
def abort_request(self, request_id: str) -> None:
"""Aborts a request with the given ID.
Args:
request_id: The ID of the request to abort.
"""
self.scheduler.abort_seq_group(request_id)
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration for the server. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copied. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
# Nothing to do.
@ -188,7 +248,7 @@ class LLMServer:
return request_outputs
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Decode the sequence outputs.
"""Decodes the sequence outputs."""
for seq_group in seq_groups:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_token, new_output_text = detokenize_incrementally(
@ -201,7 +261,7 @@ class LLMServer:
seq.output_text = new_output_text
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Stop the sequences.
"""Stop the finished sequences."""
for seq_group in seq_groups:
sampling_params = seq_group.sampling_params
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@ -238,6 +298,7 @@ class LLMServer:
*args,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
all_outputs = []
for worker in self.workers:
executor = getattr(worker, method)
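
For offline use the server can be driven directly, along the lines of the examples script at the end of this commit. A sketch, with the `LLMServer` module path and all argument values assumed:

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer  # module path assumed

# Build the server from the argument dataclass.
server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))

# Add a couple of requests to the pool; they are scheduled on `step()`.
server.add_request("0", "Hello, my name is", SamplingParams(temperature=0.8))
server.add_request("1", "The capital of France is", SamplingParams(n=2))

# Keep stepping until every request has finished; each call returns the
# newly generated (possibly partial) RequestOutput objects.
while server.has_unfinished_requests():
    for request_output in server.step():
        print(request_output)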

View File

@ -14,15 +14,30 @@ DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id
def initialize_cluster(
parallel_config: ParallelConfig,
server_use_ray: bool = False,
address: Optional[str] = None,
ray_server_address: Optional[str] = None,
) -> Tuple[str, List[List[DeviceID]]]:
"""Initialize the distributed cluster probably with Ray.
Args:
parallel_config: The configurations for parallel execution.
server_use_ray: Whether to use Ray for async server.
ray_server_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
Returns:
A tuple of (`distributed_init_method`, `all_stage_devices`). The
`distributed_init_method` is the address for initializing the
distributed backend. `all_stage_devices` includes device IDs for
each worker in each pipeline stage. Each device ID is a tuple of
(rank, node resource, device id).
"""
if parallel_config.worker_use_ray or server_use_ray:
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=address)
ray.init(address=ray_server_address)
if not parallel_config.worker_use_ray:
# Initialize cluster locally.
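
A sketch of calling this helper for the single-GPU case, with the `ParallelConfig` keyword arguments assumed from its docstring above:

from cacheflow.config import ParallelConfig
from cacheflow.server.ray_utils import initialize_cluster

# Single-GPU case: no Ray workers, no Ray-backed async server.
parallel_config = ParallelConfig(
    pipeline_parallel_size=1,
    tensor_parallel_size=1,
    worker_use_ray=False,
)
distributed_init_method, all_stage_devices = initialize_cluster(
    parallel_config, server_use_ray=False)

# `all_stage_devices[i]` lists (rank, node resource, device id) tuples
# for the workers of pipeline stage i.
print(distributed_init_method)
print(all_stage_devices)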

View File

@ -15,6 +15,7 @@ def get_tokenizer(
*args,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
config = AutoConfig.from_pretrained(model_name)
if config.model_type == "llama" and kwargs.get("use_fast", True):
# LLaMA fast tokenizer causes protobuf errors in some environments.
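
A minimal use of this helper (a sketch):

from cacheflow.server.tokenizer_utils import get_tokenizer

# Returns a Hugging Face PreTrainedTokenizer / PreTrainedTokenizerFast.
tokenizer = get_tokenizer("facebook/opt-125m")
token_ids = tokenizer.encode("Hello, world!")
print(token_ids)
print(tokenizer.decode(token_ids))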

View File

@ -1,14 +1,15 @@
import openai
# Modify OpenAI's API key and API base to use CacheFlow's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
model = "facebook/opt-125m"
# list models
# Test list models API
models = openai.Model.list()
print(models)
# create a completion
print("Models:", models)
# Test completion API
stream = True
completion = openai.Completion.create(
model=model, prompt="A robot may not injure a human being", echo=False, n=2,
@ -19,4 +20,4 @@ if stream:
for c in completion:
print(c)
else:
print("completion:", completion)
print("Completion result:", completion)

View File

@ -19,7 +19,7 @@ def main(args: argparse.Namespace):
SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
]
# Run the server.
# Run the server by calling `server.step()` manually.
request_id = 0
while True:
# To test iteration-level scheduling, we add one request at each step.