import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple

from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)


@dataclass
class ServerArgs:
    """Arguments for CacheFlow servers."""
    model: str
    download_dir: Optional[str] = None
    use_np_weights: bool = False
    use_dummy_weights: bool = False
    dtype: str = "auto"
    seed: int = 0
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    block_size: int = 16
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.95
    max_num_batched_tokens: int = 2560
    max_num_seqs: int = 256
    disable_log_stats: bool = False

    def __post_init__(self):
        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
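        # Worked example (editor's illustration): with
        # max_num_batched_tokens=128 and max_num_seqs=256, the effective
        # max_num_seqs becomes 128, since a batch of N sequences always
        # contains at least N tokens.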

    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Shared CLI arguments for CacheFlow servers."""
        # Model arguments
        parser.add_argument('--model', type=str, default='facebook/opt-125m',
                            help='name or path of the huggingface model to use')
        parser.add_argument('--download-dir', type=str,
                            default=ServerArgs.download_dir,
                            help='directory to download and load the weights; '
                                 'defaults to the default huggingface cache '
                                 'dir')
        parser.add_argument('--use-np-weights', action='store_true',
                            help='save a numpy copy of model weights for '
                                 'faster loading. This can increase the disk '
                                 'usage by up to 2x.')
        parser.add_argument('--use-dummy-weights', action='store_true',
                            help='use dummy values for model weights')
        # TODO(woosuk): Support FP32.
        parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
                            choices=['auto', 'half', 'bfloat16', 'float'],
                            help='data type for model weights and activations. '
                                 'The "auto" option will use FP16 precision '
                                 'for FP32 and FP16 models, and BF16 precision '
                                 'for BF16 models.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray', action='store_true',
                            help='use Ray for distributed serving; this is '
                                 'set automatically when more than 1 GPU is '
                                 'used')
        parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
                            default=ServerArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size', '-tp', type=int,
                            default=ServerArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        # KV cache arguments
        parser.add_argument('--block-size', type=int,
                            default=ServerArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed', type=int, default=ServerArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space', type=int,
                            default=ServerArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument('--gpu-memory-utilization', type=float,
                            default=ServerArgs.gpu_memory_utilization,
                            help='the fraction of GPU memory to be used for '
                                 'the model executor')
        parser.add_argument('--max-num-batched-tokens', type=int,
                            default=ServerArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                                 'iteration')
        parser.add_argument('--max-num-seqs', type=int,
                            default=ServerArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--disable-log-stats', action='store_true',
                            help='disable logging statistics')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return server_args
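
    # Typical round trip (editor's sketch, not part of the original module):
    #
    #   parser = ServerArgs.add_cli_args(argparse.ArgumentParser())
    #   server_args = ServerArgs.from_cli_args(parser.parse_args())
    #
    # Because from_cli_args is a classmethod that iterates
    # dataclasses.fields(cls), it also works unchanged for subclasses such
    # as AsyncServerArgs below.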

    def create_server_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
        # Initialize the configs.
        model_config = ModelConfig(
            self.model, self.download_dir, self.use_np_weights,
            self.use_dummy_weights, self.dtype, self.seed)
        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
                                   self.swap_space)
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs)
        return model_config, cache_config, parallel_config, scheduler_config


@dataclass
class AsyncServerArgs(ServerArgs):
    """Arguments for asynchronous CacheFlow servers."""
    server_use_ray: bool = False

    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        parser = ServerArgs.add_cli_args(parser)
        parser.add_argument('--server-use-ray', action='store_true',
                            help='use Ray to start the LLM server in a '
                                 'process separate from the web server '
                                 'process')
        return parser
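

# Minimal demo entry point (editor's sketch: this block is an illustration
# added for clarity, not part of the original module; it only exercises the
# public helpers defined above).
if __name__ == "__main__":
    demo_parser = argparse.ArgumentParser(
        description="CacheFlow server argument demo")
    demo_parser = AsyncServerArgs.add_cli_args(demo_parser)
    demo_args = demo_parser.parse_args()
    server_args = AsyncServerArgs.from_cli_args(demo_args)
    # create_server_configs() turns the flat CLI namespace into the four
    # structured config objects consumed elsewhere in CacheFlow.
    for config in server_args.create_server_configs():
        print(type(config).__name__)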