Add docstrings for LLMServer and related classes and examples (#142)

Parent: e38074b1e6
Commit: 4298374265
@@ -12,6 +12,20 @@ _GiB = 1 << 30


class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        download_dir: Directory to download and load the weights, defaults to
            the default cache directory of huggingface.
        use_np_weights: Save a numpy copy of model weights for faster loading.
            This can increase the disk usage by up to 2x.
        use_dummy_weights: Use dummy values for model weights (for profiling).
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
@@ -68,7 +82,14 @@ class ModelConfig:


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            CacheFlow execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).
    """
    def __init__(
        self,
        block_size: int,
@@ -111,7 +132,15 @@ class CacheConfig:


class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.
    """
    def __init__(
        self,
        pipeline_parallel_size: int,
@@ -134,7 +163,14 @@ class ParallelConfig:


class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
    """
    def __init__(
        self,
        max_num_batched_tokens: int,
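
For reference, a minimal sketch (not part of the commit) of constructing these config objects directly. The constructor bodies are truncated in the hunks above, so the keyword-argument form and the concrete values below are assumptions based only on the documented parameters.

from cacheflow.config import CacheConfig, ParallelConfig, SchedulerConfig

# Illustrative values; units are tokens (block_size), a fraction of GPU
# memory (gpu_memory_utilization), and GiB of CPU swap space per GPU.
cache_config = CacheConfig(block_size=16, gpu_memory_utilization=0.9, swap_space=4)
parallel_config = ParallelConfig(pipeline_parallel_size=1, tensor_parallel_size=1,
                                 worker_use_ray=False)
scheduler_config = SchedulerConfig(max_num_batched_tokens=2560, max_num_seqs=256)
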
@@ -96,6 +96,18 @@ def create_logprobs(token_ids: List[int],

@app.post("/v1/completions")
async def create_completion(raw_request: Request):
    """Completion API similar to OpenAI's API.

    See https://platform.openai.com/docs/api-reference/completions/create
    for the API specification. This API mimics the OpenAI Completion API.

    NOTE: Currently we do not support the following features:
        - echo (since the cacheflow server does not currently support
          getting the logprobs of prompt tokens)
        - suffix (the language models we currently support do not support
          suffix)
        - logit_bias (to be supported in cacheflow server)
    """
    request = CompletionRequest(**await raw_request.json())
    logger.info(f"Received completion request: {request}")
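
As a usage sketch (not from the commit), the endpoint can be exercised with a plain HTTP POST. The host and port assume a local server on port 8000, matching the OpenAI client example later in this diff; the request fields follow the OpenAI Completions spec.

import requests

# Hypothetical local deployment; adjust host/port to your server.
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "A robot may not injure a human being",
        "max_tokens": 16,
    },
)
print(response.json())
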
@@ -18,6 +18,12 @@ app = FastAPI()

@app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse:
    """Stream the results of the generation request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    sampling_params = SamplingParams(**request_dict)
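
A hedged client-side sketch of the request format the docstring describes: a JSON body with a prompt plus any SamplingParams fields. The sampling fields shown (n, temperature) appear elsewhere in this diff; port 8000 and line-delimited streaming are assumptions about the deployment.

import requests

payload = {
    "prompt": "A robot may not injure a human being",
    # Remaining fields are forwarded to SamplingParams(**request_dict).
    "n": 2,
    "temperature": 0.8,
}
# The endpoint returns a StreamingResponse, so read it incrementally.
response = requests.post("http://localhost:8000/generate", json=payload, stream=True)
for chunk in response.iter_lines():
    if chunk:
        print(chunk.decode("utf-8"))
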
@@ -9,6 +9,7 @@ from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,

@dataclass
class ServerArgs:
    """Arguments for CacheFlow servers."""
    model: str
    download_dir: Optional[str] = None
    use_np_weights: bool = False
@@ -117,6 +118,7 @@ class ServerArgs:

@dataclass
class AsyncServerArgs(ServerArgs):
    """Arguments for asynchronous CacheFlow servers."""
    server_use_ray: bool = False

    @staticmethod
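
A minimal construction sketch tying these argument dataclasses to the server classes documented further below. Only the `model` field is required; everything else keeps its default. The module path for LLMServer is an assumption, and the example is not part of the commit.

from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer  # module path assumed

server_args = ServerArgs(model="facebook/opt-125m")
server = LLMServer.from_server_args(server_args)
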
@@ -1,6 +1,6 @@
import asyncio
import time
from typing import Dict, Optional
from typing import Dict, List, Optional

from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
@@ -15,7 +15,25 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds


class AsyncLLMServer:
    """An asynchronous wrapper for LLMServer.

    This class is used to wrap the LLMServer class to make it asynchronous. It
    uses asyncio to create a background loop that keeps processing incoming
    requests. The LLMServer is kicked by the generate method when there
    are requests in the waiting queue. The generate method yields the outputs
    from the LLMServer to the caller.

    NOTE: For the comprehensive list of arguments, see `LLMServer`.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        server_use_ray: Whether to make LLMServer a Ray actor. If so, the
            async frontend will be executed in a separate process from the
            model workers.
        *args, **kwargs: Arguments for LLMServer.
    """
    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
                 *args, **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
@@ -35,6 +53,7 @@ class AsyncLLMServer:
        self.kicking_request_id: Optional[str] = None

    async def server_step(self, kicking_request_id: Optional[str] = None):
        """Kick the server to process the waiting requests."""
        self.is_server_running = True
        self.kicking_request_id = kicking_request_id
        if self.server_use_ray:
@@ -54,8 +73,31 @@ class AsyncLLMServer:
            self.request_outputs[request_id] = request_output
            self.request_events[request_id].set()

    async def generate(self, prompt: str, sampling_params: SamplingParams,
                       request_id: str) -> RequestOutput:
    async def generate(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        request_id: str,
        prompt_token_ids: Optional[List[int]] = None
    ) -> RequestOutput:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMServer and streams the outputs
        from the LLMServer to the caller.

        Args:
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompt to token IDs.

        Yields:
            The output `RequestOutput` objects from the LLMServer for the
            request.
        """
        # Preprocess the request.
        arrival_time = time.time()

@@ -66,20 +108,29 @@ class AsyncLLMServer:

        logger.info(f"Received request {request_id}: "
                    f"prompt: {prompt!r}, "
                    f"sampling params: {sampling_params}.")
                    f"sampling params: {sampling_params}, "
                    f"prompt token ids: {prompt_token_ids}.")

        # Add the request into the cacheflow server's waiting queue.
        if self.server_use_ray:
            await self.server.add_request.remote(
                request_id, prompt, sampling_params, arrival_time=arrival_time)
                request_id, prompt, sampling_params,
                prompt_token_ids=prompt_token_ids,
                arrival_time=arrival_time)
        else:
            self.server.add_request(
                request_id, prompt, sampling_params, arrival_time=arrival_time)
                request_id, prompt, sampling_params,
                prompt_token_ids=prompt_token_ids,
                arrival_time=arrival_time)

        # The cacheflow server does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the server to process the requests.
        while True:
            if request_id not in self.request_events:
                # The request has been aborted.
                return

            # Kick the server if the server is not running.
            if not self.is_server_running:
                await self.server_step(request_id)
@@ -113,6 +164,14 @@ class AsyncLLMServer:
                break

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        if request_id not in self.request_events:
            # The request has already finished or been aborted.
            return
@@ -137,6 +196,7 @@ class AsyncLLMServer:

    @classmethod
    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
        """Creates an async LLM server from the server arguments."""
        # Create the server configs.
        server_configs = server_args.create_server_configs()
        parallel_config = server_configs[2]
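
To make the streaming contract of `generate` concrete, here is a hedged consumption sketch (not from the commit). It assumes an already constructed AsyncLLMServer instance; the uuid-based request id and the default SamplingParams are illustrative choices.

import uuid

from cacheflow.sampling_params import SamplingParams

async def stream_one_request(server, prompt: str) -> None:
    request_id = str(uuid.uuid4())  # any unique string works
    # generate() is an async generator: it yields RequestOutput objects as
    # more tokens are produced for this request.
    async for request_output in server.generate(prompt, SamplingParams(), request_id):
        print(request_output)
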
@@ -8,7 +8,7 @@ from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import ray, initialize_cluster
from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                              detokenize_incrementally)
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -19,6 +19,33 @@ logger = init_logger(__name__)


class LLMServer:
    """An LLM server that receives requests and generates texts.

    This is the main class for the CacheFlow LLM server. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMServer` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `ServerArgs` class. For the
    comprehensive list of arguments, see `ServerArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        distributed_init_method: The initialization method for distributed
            execution. See `torch.distributed.init_process_group` for details.
        stage_devices: The list of devices for each stage. Each stage is a list
            of (rank, node_resource, device) tuples.
        log_stats: Whether to log statistics.
    """

    def __init__(
        self,
@@ -27,7 +54,7 @@ class LLMServer:
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        distributed_init_method: str,
        stage_devices: List[List[Any]],
        stage_devices: List[List[DeviceID]],
        log_stats: bool,
    ) -> None:
        logger.info(
@@ -83,6 +110,7 @@ class LLMServer:
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache."""
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
@@ -108,6 +136,7 @@ class LLMServer:

    @classmethod
    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
        """Creates an LLM server from the server arguments."""
        # Create the server configs.
        server_configs = server_args.create_server_configs()
        parallel_config = server_configs[2]
@@ -126,6 +155,22 @@ class LLMServer:
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
        """Add a request to the server's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `server.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompt to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current time.
        """
        if arrival_time is None:
            arrival_time = time.time()
        if prompt_token_ids is None:
@@ -148,15 +193,30 @@ class LLMServer:
        self.scheduler.add_seq_group(seq_group)

    def abort_request(self, request_id: str) -> None:
        """Aborts a request with the given ID.

        Args:
            request_id: The ID of the request to abort.
        """
        self.scheduler.abort_seq_group(request_id)

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        This function performs one decoding iteration for the server. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copied. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
            # Nothing to do.
@@ -188,7 +248,7 @@ class LLMServer:
        return request_outputs

    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        # Decode the sequence outputs.
        """Decodes the sequence outputs."""
        for seq_group in seq_groups:
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                new_token, new_output_text = detokenize_incrementally(
@@ -201,7 +261,7 @@ class LLMServer:
                seq.output_text = new_output_text

    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        # Stop the sequences.
        """Stop the finished sequences."""
        for seq_group in seq_groups:
            sampling_params = seq_group.sampling_params
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -238,6 +298,7 @@ class LLMServer:
        *args,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        all_outputs = []
        for worker in self.workers:
            executor = getattr(worker, method)
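
A hedged end-to-end sketch of the offline loop these methods describe, in the spirit of the simple-server example at the end of this diff: add a request, then call `step()` until no requests remain. The LLMServer module path and the concrete values are illustrative, not taken from the commit.

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer  # module path assumed

server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))
server.add_request(request_id="0",
                   prompt="A robot may not injure a human being",
                   sampling_params=SamplingParams())

while server.has_unfinished_requests():
    # Each step() performs one decoding iteration and returns the
    # RequestOutput objects produced in that iteration.
    for request_output in server.step():
        print(request_output)
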
@@ -14,15 +14,30 @@ DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id
def initialize_cluster(
    parallel_config: ParallelConfig,
    server_use_ray: bool = False,
    address: Optional[str] = None,
    ray_server_address: Optional[str] = None,
) -> Tuple[str, List[List[DeviceID]]]:
    """Initialize the distributed cluster, possibly with Ray.

    Args:
        parallel_config: The configurations for parallel execution.
        server_use_ray: Whether to use Ray for async server.
        ray_server_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Returns:
        A tuple of (`distributed_init_method`, `all_stage_devices`). The
        `distributed_init_method` is the address for initializing the
        distributed backend. `all_stage_devices` includes device IDs for
        each worker in each pipeline stage. Each device ID is a tuple of
        (rank, node resource, device id).
    """
    if parallel_config.worker_use_ray or server_use_ray:
        if ray is None:
            raise ImportError(
                "Ray is not installed. Please install Ray to use distributed "
                "serving.")
        # Connect to a ray cluster.
        ray.init(address=address)
        ray.init(address=ray_server_address)

    if not parallel_config.worker_use_ray:
        # Initialize cluster locally.
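
A brief hedged sketch of consuming the return value described in the Returns section above; the ParallelConfig constructor keywords are assumed from its docstring earlier in this diff, and the call site is illustrative rather than the commit's actual wiring.

from cacheflow.config import ParallelConfig
from cacheflow.server.ray_utils import initialize_cluster

parallel_config = ParallelConfig(pipeline_parallel_size=1, tensor_parallel_size=1,
                                 worker_use_ray=False)
distributed_init_method, all_stage_devices = initialize_cluster(parallel_config)
for stage in all_stage_devices:
    # Each device ID is a (rank, node resource, device id) tuple.
    for rank, node_resource, device_id in stage:
        print(rank, node_resource, device_id)
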
@@ -15,6 +15,7 @@ def get_tokenizer(
    *args,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface."""
    config = AutoConfig.from_pretrained(model_name)
    if config.model_type == "llama" and getattr(kwargs, "use_fast", True):
        # LLaMA fast tokenizer causes protobuf errors in some environments.
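
A one-line usage sketch (not from the commit); extra arguments appear to be forwarded to the Huggingface tokenizer loader, so `use_fast` below is only an illustrative pass-through.

from cacheflow.server.tokenizer_utils import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m", use_fast=True)
print(tokenizer.encode("A robot may not injure a human being"))
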
@@ -1,14 +1,15 @@
import openai

# Modify OpenAI's API key and API base to use CacheFlow's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
model = "facebook/opt-125m"

# list models
# Test list models API
models = openai.Model.list()
print(models)

# create a completion
print("Models:", models)

# Test completion API
stream = True
completion = openai.Completion.create(
    model=model, prompt="A robot may not injure a human being", echo=False, n=2,
@@ -19,4 +20,4 @@ if stream:
    for c in completion:
        print(c)
else:
    print("completion:", completion)
    print("Completion result:", completion)
@@ -19,7 +19,7 @@ def main(args: argparse.Namespace):
        SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
    ]

    # Run the server.
    # Run the server by calling `server.step()` manually.
    request_id = 0
    while True:
        # To test iteration-level scheduling, we add one request at each step.