Introduce LLM class for offline inference (#115)

Woosuk Kwon 2023-05-21 17:04:18 -07:00 committed by GitHub
parent f746ced08d
commit 655a5e48df
9 changed files with 222 additions and 81 deletions

View File

@@ -1,19 +1,15 @@
from cacheflow.entrypoints.llm import LLM
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import (
add_server_arguments,
create_server_configs_from_args,
initialize_server_from_args,
)
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster
__all__ = [
"RequestOutput",
"LLM",
"SamplingParams",
"RequestOutput",
"LLMServer",
"add_server_arguments",
"create_server_configs_from_args",
"initialize_server_from_args",
"ServerArgs",
"initialize_cluster",
]
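
With the re-exports above, a downstream script can reach both the offline entrypoint and the server-side pieces from the top-level package. A minimal sketch of what that looks like after this commit (the model name is just the repository's default and the snippet is illustrative, not part of the diff):

from cacheflow import LLM, RequestOutput, SamplingParams

# Offline inference through the new top-level entrypoint.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(["Hello, my name is"], params)
assert isinstance(outputs[0], RequestOutput)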

View File

@@ -3,6 +3,8 @@ from typing import Optional
import torch
from transformers import AutoConfig, PretrainedConfig
_GiB = 1 << 30
class ModelConfig:
@@ -70,7 +72,7 @@ class CacheConfig:
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space = swap_space
self.swap_space_bytes = swap_space * _GiB
# Will be set after profiling.
self.num_gpu_blocks = None
@@ -138,6 +140,8 @@ def _get_and_verify_dtype(
else:
torch_dtype = config_dtype
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
# Verify the dtype.
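
One behavioral detail worth noting in this file: CacheConfig now converts the CLI-facing GiB value into bytes itself (swap_space_bytes), instead of relying on the argument post-processing that used to live in arg_utils. A quick sketch of the arithmetic, using the 4 GiB default:

_GiB = 1 << 30            # 1 GiB = 2**30 bytes
swap_space = 4            # --swap-space is given in GiB per GPU
swap_space_bytes = swap_space * _GiB
assert swap_space_bytes == 4_294_967_296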

View File

@@ -12,8 +12,7 @@ import uvicorn
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import (
add_server_arguments, create_server_configs_from_args)
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster
@@ -116,10 +115,10 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=10002)
parser = add_server_arguments(parser)
parser = ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_configs = create_server_configs_from_args(args)
server_configs = ServerArgs.from_cli_args(args).create_server_configs()
parallel_config = server_configs[2]
distributed_init_method, stage_devices = initialize_cluster(parallel_config)
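
For reference, the frontend's argument handling now round-trips entirely through the dataclass rather than the removed helper functions. A standalone sketch of that wiring, assuming this commit is installed (the explicit argv list is only for illustration):

import argparse

from cacheflow.server.arg_utils import ServerArgs

parser = argparse.ArgumentParser()
parser = ServerArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "facebook/opt-125m", "--swap-space", "4"])

# ServerArgs.from_cli_args picks the dataclass fields out of the namespace,
# and create_server_configs builds the four config objects.
model_config, cache_config, parallel_config, scheduler_config = (
    ServerArgs.from_cli_args(args).create_server_configs())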

View File

@@ -0,0 +1,62 @@
from typing import List, Optional
from tqdm import tqdm
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer
from cacheflow.utils import Counter
class LLM:
def __init__(
self,
model: str,
tensor_parallel_size: int = 1,
dtype: str = "default",
seed: int = 0,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
server_args = ServerArgs(
model=model,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
seed=seed,
**kwargs,
)
self.llm_server = LLMServer.from_server_args(server_args)
self.request_counter = Counter()
def generate(
self,
prompts: List[str],
sampling_params: Optional[SamplingParams] = None,
use_tqdm: bool = True,
) -> List[RequestOutput]:
if sampling_params is None:
sampling_params = SamplingParams()
# Initialize tqdm.
if use_tqdm:
pbar = tqdm(total=len(prompts), desc="Processed prompts")
# Add requests to the server.
for prompt in prompts:
request_id = str(next(self.request_counter))
self.llm_server.add_request(request_id, prompt, sampling_params)
# Run the server.
outputs: List[RequestOutput] = []
while self.llm_server.has_unfinished_requests():
step_outputs = self.llm_server.step()
for output in step_outputs:
if output.done:
outputs.append(output)
if use_tqdm:
pbar.update(1)
if use_tqdm:
pbar.close()
return outputs
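
Because the constructor forwards **kwargs into ServerArgs, any field on that dataclass can be overridden from Python, and disable_log_stats is set to True by default for the offline case unless the caller passes it explicitly. A hedged sketch of what that allows (the keyword values below are simply the ServerArgs defaults, shown for illustration):

from cacheflow import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=1,
    dtype="default",
    seed=0,
    gpu_memory_utilization=0.95,   # forwarded via **kwargs into ServerArgs
    disable_log_stats=False,       # opt back into stats logging
)
outputs = llm.generate(
    ["The future of AI is"],
    SamplingParams(temperature=0.8, top_p=0.95),
    use_tqdm=False,
)
for output in outputs:
    print(output.outputs[0].text)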

View File

@@ -35,7 +35,7 @@ class RequestOutput:
prompt: str,
prompt_token_ids: List[int],
outputs: List[CompletionOutput],
done: bool = False,
done: bool,
) -> None:
self.request_id = request_id
self.prompt = prompt
@@ -43,8 +43,8 @@
self.outputs = outputs
self.done = done
@staticmethod
def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
# Get the top-n sequences.
n = seq_group.sampling_params.n
seqs = seq_group.get_seqs()
@@ -70,8 +70,8 @@ class RequestOutput:
# Every sequence in the sequence group should have the same prompt.
prompt = top_n_seqs[0].prompt
prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
outputs, seq_group.is_finished())
return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
seq_group.is_finished())
def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
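
Two small API notes on this file: done is now a required constructor argument rather than defaulting to False, and from_seq_group became a classmethod so that the object it builds follows the class it is called on. A toy illustration of the latter (not cacheflow code, just the general pattern):

class Base:
    @classmethod
    def build(cls) -> "Base":
        # With the old staticmethod style this would hard-code Base(),
        # so subclasses could never get instances of themselves back.
        return cls()

class Child(Base):
    pass

assert type(Child.build()) is Child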

View File

@@ -1,74 +1,117 @@
import argparse
from typing import Tuple
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster
_GiB = 1 << 30
def add_server_arguments(parser: argparse.ArgumentParser):
"""Shared arguments for CacheFlow servers."""
@dataclass
class ServerArgs:
model: str
download_dir: Optional[str] = None
use_np_weights: bool = False
use_dummy_weights: bool = False
dtype: str = "default"
seed: int = 0
use_ray: bool = False
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
block_size: int = 16
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.95
max_num_batched_tokens: int = 2560
max_num_seqs: int = 256
disable_log_stats: bool = False
def __post_init__(self):
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
return _add_server_arguments(parser)
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return server_args
def create_server_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Initialize the configs.
model_config = ModelConfig(
self.model, self.download_dir, self.use_np_weights,
self.use_dummy_weights, self.dtype, self.seed)
cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
self.swap_space)
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.use_ray)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs)
return model_config, cache_config, parallel_config, scheduler_config
def _add_server_arguments(
parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
"""Shared CLI arguments for CacheFlow servers."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--download-dir', type=str, default=None,
parser.add_argument('--model', type=str, default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument('--download-dir', type=str,
default=ServerArgs.download_dir,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument('--use-np-weights', action='store_true',
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
help='save a numpy copy of model weights for faster '
'loading. This can increase the disk usage by up '
'to 2x.')
parser.add_argument('--use-dummy-weights', action='store_true',
help='use dummy values for model weights')
# TODO(woosuk): Support FP32.
parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
choices=['default', 'half', 'bfloat16'],
help=('data type for model weights and activations. '
'The "default" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.'))
# Parallel arguments
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
parser.add_argument('--use-ray', action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
default=ServerArgs.pipeline_parallel_size,
help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int,
default=ServerArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=4, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-seqs', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
parser.add_argument('--seed', type=int, default=ServerArgs.seed,
help='random seed')
parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float,
default=ServerArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used for the '
'model executor')
parser.add_argument('--max-num-batched-tokens', type=int,
default=ServerArgs.max_num_batched_tokens,
help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-seqs', type=int,
default=ServerArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true',
help='disable logging statistics')
return parser
def create_server_configs_from_args(
args: argparse.Namespace,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Post-process the parsed arguments.
args.swap_space = args.swap_space * _GiB
args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)
# Initialize the configs.
model_config = ModelConfig(
args.model, args.download_dir, args.use_np_weights,
args.use_dummy_weights, args.dtype, args.seed)
cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
args.swap_space)
parallel_config = ParallelConfig(args.pipeline_parallel_size,
args.tensor_parallel_size, args.use_ray)
scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
args.max_num_seqs)
return model_config, cache_config, parallel_config, scheduler_config
def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
server_configs = create_server_configs_from_args(args)
parallel_config = server_configs[2]
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config)
# Create the LLM server.
server = LLMServer(*server_configs, distributed_init_method, devices,
log_stats=not args.disable_log_stats)
return server
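
Since ServerArgs is now a plain dataclass, the same configuration path is available without argparse at all, which is what the new LLM class relies on. A sketch of constructing it programmatically (field names and defaults are the ones defined above):

from cacheflow.server.arg_utils import ServerArgs

server_args = ServerArgs(
    model="facebook/opt-125m",
    tensor_parallel_size=1,
    swap_space=4,   # still given in GiB; bytes are derived inside CacheConfig
)
# __post_init__ clamps max_num_seqs to max_num_batched_tokens.
assert server_args.max_num_seqs <= server_args.max_num_batched_tokens

model_config, cache_config, parallel_config, scheduler_config = (
    server_args.create_server_configs())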

View File

@@ -12,6 +12,8 @@ from cacheflow.core.scheduler import Scheduler
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import initialize_cluster
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
from cacheflow.utils import Counter
@@ -30,7 +32,7 @@ class LLMServer:
scheduler_config: SchedulerConfig,
distributed_init_method: str,
stage_devices: List[List[Any]],
log_stats: bool = True,
log_stats: bool,
) -> None:
logger.info(
"Initializing an LLM server with config: "
@@ -90,7 +92,7 @@ class LLMServer:
get_all_outputs=True,
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space,
cpu_swap_space=self.cache_config.swap_space_bytes,
)
# Since we use a shared centralized controller, we take the minimum
@@ -107,6 +109,18 @@
# Initialize the cache.
self._run_workers("init_cache_engine", cache_config=self.cache_config)
@classmethod
def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
# Create the server configs.
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config)
# Create the LLM server.
server = cls(*server_configs, distributed_init_method, devices,
log_stats=not server_args.disable_log_stats)
return server
def add_request(
self,
request_id: str,
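
The new from_server_args constructor makes the server usable without the removed initialize_server_from_args helper; combined with add_request and step, the driving loop from LLM.generate can also be written by hand. A sketch, assuming this commit is installed:

from cacheflow import LLMServer, SamplingParams, ServerArgs

server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))
server.add_request("0", "The capital of France is",
                   SamplingParams(temperature=0.8, top_p=0.95))

# Step the server until every request has produced its final output.
while server.has_unfinished_requests():
    for output in server.step():
        if output.done:
            print(output.outputs[0].text)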

View File

@@ -0,0 +1,23 @@
from cacheflow import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@@ -1,13 +1,13 @@
import argparse
import uuid
from cacheflow import (add_server_arguments, initialize_server_from_args,
SamplingParams)
from cacheflow import ServerArgs, LLMServer, SamplingParams
def main(args: argparse.Namespace):
# Initialize the server.
server = initialize_server_from_args(args)
# Parse the CLI argument and initialize the server.
server_args = ServerArgs.from_cli_args(args)
server = LLMServer.from_server_args(server_args)
# Test the following prompts.
test_prompts = [
@@ -39,6 +39,6 @@ def main(args: argparse.Namespace):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
parser = add_server_arguments(parser)
parser = ServerArgs.add_cli_args(parser)
args = parser.parse_args()
main(args)