Introduce LLM class for offline inference (#115)

Author: Woosuk Kwon, 2023-05-21 17:04:18 -07:00 (committed by GitHub)
parent f746ced08d
commit 655a5e48df
9 changed files with 222 additions and 81 deletions


@@ -1,19 +1,15 @@
+from cacheflow.entrypoints.llm import LLM
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments,
-    create_server_configs_from_args,
-    initialize_server_from_args,
-)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster

 __all__ = [
-    "RequestOutput",
+    "LLM",
     "SamplingParams",
+    "RequestOutput",
     "LLMServer",
-    "add_server_arguments",
-    "create_server_configs_from_args",
-    "initialize_server_from_args",
+    "ServerArgs",
     "initialize_cluster",
 ]
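
Not part of the diff: after this change the package exposes two entry points, the new LLM class for offline batched inference and the existing LLMServer configured through the new ServerArgs dataclass. A minimal import sketch, assuming the cacheflow package from this commit is installed:

# Offline inference path and serving path, both from the top-level package.
from cacheflow import LLM, LLMServer, RequestOutput, SamplingParams, ServerArgs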


@@ -3,6 +3,8 @@ from typing import Optional
 import torch
 from transformers import AutoConfig, PretrainedConfig

+_GiB = 1 << 30
+

 class ModelConfig:
@@ -70,7 +72,7 @@ class CacheConfig:
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space = swap_space
+        self.swap_space_bytes = swap_space * _GiB

         # Will be set after profiling.
         self.num_gpu_blocks = None
@@ -138,6 +140,8 @@ def _get_and_verify_dtype(
         else:
             torch_dtype = config_dtype
     else:
+        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+            raise ValueError(f"Unknown dtype: {dtype}")
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

     # Verify the dtype.
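
Not part of the diff: with this change CacheConfig keeps the swap space in bytes rather than GiB. A standalone sketch of the conversion, plain Python rather than the CacheConfig class itself:

_GiB = 1 << 30           # 1 GiB = 1024 ** 3 bytes
swap_space = 4           # CLI value, in GiB (the default)
swap_space_bytes = swap_space * _GiB
assert swap_space_bytes == 4 * 1024 ** 3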


@@ -12,8 +12,7 @@ import uvicorn
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments, create_server_configs_from_args)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster
@@ -116,10 +115,10 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=10002)
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_configs = create_server_configs_from_args(args)
+    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
     parallel_config = server_configs[2]
     distributed_init_method, stage_devices = initialize_cluster(parallel_config)


@@ -0,0 +1,62 @@
+from typing import List, Optional
+
+from tqdm import tqdm
+
+from cacheflow.outputs import RequestOutput
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.llm_server import LLMServer
+from cacheflow.utils import Counter
+
+
+class LLM:
+
+    def __init__(
+        self,
+        model: str,
+        tensor_parallel_size: int = 1,
+        dtype: str = "default",
+        seed: int = 0,
+        **kwargs,
+    ) -> None:
+        if "disable_log_stats" not in kwargs:
+            kwargs["disable_log_stats"] = True
+        server_args = ServerArgs(
+            model=model,
+            tensor_parallel_size=tensor_parallel_size,
+            dtype=dtype,
+            seed=seed,
+            **kwargs,
+        )
+        self.llm_server = LLMServer.from_server_args(server_args)
+        self.request_counter = Counter()
+
+    def generate(
+        self,
+        prompts: List[str],
+        sampling_params: Optional[SamplingParams] = None,
+        use_tqdm: bool = True,
+    ) -> List[RequestOutput]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        # Initialize tqdm.
+        if use_tqdm:
+            pbar = tqdm(total=len(prompts), desc="Processed prompts")
+
+        # Add requests to the server.
+        for prompt in prompts:
+            request_id = str(next(self.request_counter))
+            self.llm_server.add_request(request_id, prompt, sampling_params)
+
+        # Run the server.
+        outputs: List[RequestOutput] = []
+        while self.llm_server.has_unfinished_requests():
+            step_outputs = self.llm_server.step()
+            for output in step_outputs:
+                if output.done:
+                    outputs.append(output)
+                    if use_tqdm:
+                        pbar.update(1)
+        if use_tqdm:
+            pbar.close()
+        return outputs
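
Not part of the diff: the request IDs above come from cacheflow.utils.Counter, whose implementation is not shown here. A hypothetical stand-in with the behavior generate() relies on, an iterator yielding 0, 1, 2, ...:

class Counter:
    """Hypothetical stand-in for cacheflow.utils.Counter."""

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        value = self.counter
        self.counter += 1
        return value


# Mirrors LLM.generate(): request IDs become "0", "1", "2", ...
counter = Counter()
assert [str(next(counter)) for _ in range(3)] == ["0", "1", "2"]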


@@ -35,7 +35,7 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool = False,
+        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -43,8 +43,8 @@ class RequestOutput:
         self.outputs = outputs
         self.done = done

-    @staticmethod
-    def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
+    @classmethod
+    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
@@ -70,8 +70,8 @@ class RequestOutput:
         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
-                             outputs, seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
+                   seq_group.is_finished())

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "


@@ -1,74 +1,117 @@
 import argparse
-from typing import Tuple
+import dataclasses
+from dataclasses import dataclass
+from typing import Optional, Tuple

 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
-from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-
-_GiB = 1 << 30


-def add_server_arguments(parser: argparse.ArgumentParser):
-    """Shared arguments for CacheFlow servers."""
+@dataclass
+class ServerArgs:
+    model: str
+    download_dir: Optional[str] = None
+    use_np_weights: bool = False
+    use_dummy_weights: bool = False
+    dtype: str = "default"
+    seed: int = 0
+    use_ray: bool = False
+    pipeline_parallel_size: int = 1
+    tensor_parallel_size: int = 1
+    block_size: int = 16
+    swap_space: int = 4  # GiB
+    gpu_memory_utilization: float = 0.95
+    max_num_batched_tokens: int = 2560
+    max_num_seqs: int = 256
+    disable_log_stats: bool = False
+
+    def __post_init__(self):
+        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
+
+    @staticmethod
+    def add_cli_args(
+        parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        return _add_server_arguments(parser)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        # Set the attributes from the parsed arguments.
+        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return server_args
+
+    def create_server_configs(
+        self,
+    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
+        # Initialize the configs.
+        model_config = ModelConfig(
+            self.model, self.download_dir, self.use_np_weights,
+            self.use_dummy_weights, self.dtype, self.seed)
+        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
+                                   self.swap_space)
+        parallel_config = ParallelConfig(self.pipeline_parallel_size,
+                                         self.tensor_parallel_size,
+                                         self.use_ray)
+        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
+                                           self.max_num_seqs)
+        return model_config, cache_config, parallel_config, scheduler_config
+
+
+def _add_server_arguments(
+    parser: argparse.ArgumentParser,
+) -> argparse.ArgumentParser:
+    """Shared CLI arguments for CacheFlow servers."""
     # Model arguments
-    parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
-    parser.add_argument('--download-dir', type=str, default=None,
+    parser.add_argument('--model', type=str, default='facebook/opt-125m',
+                        help='name or path of the huggingface model to use')
+    parser.add_argument('--download-dir', type=str,
+                        default=ServerArgs.download_dir,
                         help='directory to download and load the weights, '
                              'default to the default cache dir of huggingface')
     parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster loading')
-    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
+                        help='save a numpy copy of model weights for faster '
+                             'loading. This can increase the disk usage by up '
+                             'to 2x.')
+    parser.add_argument('--use-dummy-weights', action='store_true',
+                        help='use dummy values for model weights')
     # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
+    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+                        choices=['default', 'half', 'bfloat16'],
                         help=('data type for model weights and activations. '
                               'The "default" option will use FP16 precision '
                               'for FP32 and FP16 models, and BF16 precision '
                               'for BF16 models.'))
     # Parallel arguments
-    parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
+    parser.add_argument('--use-ray', action='store_true',
+                        help='use Ray for distributed serving, will be '
+                             'automatically set when using more than 1 GPU')
+    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                        default=ServerArgs.pipeline_parallel_size,
+                        help='number of pipeline stages')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                        default=ServerArgs.tensor_parallel_size,
+                        help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
+    parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
+                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                        help='token block size')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=0, help='random seed')
-    parser.add_argument('--swap-space', type=int, default=4, help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int, default=256, help='maximum number of sequences per iteration')
-    parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
+    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+                        help='random seed')
+    parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
+                        help='CPU swap space size (GiB) per GPU')
+    parser.add_argument('--gpu-memory-utilization', type=float,
+                        default=ServerArgs.gpu_memory_utilization,
+                        help='the percentage of GPU memory to be used for the '
+                             'model executor')
+    parser.add_argument('--max-num-batched-tokens', type=int,
+                        default=ServerArgs.max_num_batched_tokens,
+                        help='maximum number of batched tokens per iteration')
+    parser.add_argument('--max-num-seqs', type=int,
+                        default=ServerArgs.max_num_seqs,
+                        help='maximum number of sequences per iteration')
+    parser.add_argument('--disable-log-stats', action='store_true',
+                        help='disable logging statistics')
     return parser
-
-
-def create_server_configs_from_args(
-    args: argparse.Namespace,
-) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
-    # Post-process the parsed arguments.
-    args.swap_space = args.swap_space * _GiB
-    args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)
-
-    # Initialize the configs.
-    model_config = ModelConfig(
-        args.model, args.download_dir, args.use_np_weights,
-        args.use_dummy_weights, args.dtype, args.seed)
-    cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
-                               args.swap_space)
-    parallel_config = ParallelConfig(args.pipeline_parallel_size,
-                                     args.tensor_parallel_size, args.use_ray)
-    scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
-                                       args.max_num_seqs)
-    return model_config, cache_config, parallel_config, scheduler_config
-
-
-def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
-    server_configs = create_server_configs_from_args(args)
-    parallel_config = server_configs[2]
-
-    # Initialize the cluster.
-    distributed_init_method, devices = initialize_cluster(parallel_config)
-
-    # Create the LLM server.
-    server = LLMServer(*server_configs, distributed_init_method, devices,
-                       log_stats=not args.disable_log_stats)
-    return server
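
Not part of the diff: the dataclass is the single source of truth here, CLI defaults read from the ServerArgs fields and from_cli_args rebuilds the dataclass from the parsed namespace via dataclasses.fields. A self-contained toy version of that round trip, with hypothetical names:

import argparse
import dataclasses
from dataclasses import dataclass


@dataclass
class ToyArgs:
    model: str = "facebook/opt-125m"
    block_size: int = 16

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # CLI defaults come from the dataclass fields.
        parser.add_argument('--model', type=str, default=ToyArgs.model)
        parser.add_argument('--block-size', type=int, default=ToyArgs.block_size)
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "ToyArgs":
        # Rebuild the dataclass from the parsed namespace.
        attrs = [field.name for field in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})


parser = ToyArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args(['--block-size', '32'])
assert ToyArgs.from_cli_args(args) == ToyArgs(block_size=32)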


@@ -12,6 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.ray_utils import initialize_cluster
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
@@ -30,7 +32,7 @@ class LLMServer:
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
         stage_devices: List[List[Any]],
-        log_stats: bool = True,
+        log_stats: bool,
     ) -> None:
         logger.info(
             "Initializing an LLM server with config: "
@@ -90,7 +92,7 @@ class LLMServer:
             get_all_outputs=True,
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
         )

         # Since we use a shared centralized controller, we take the minimum
@@ -107,6 +109,18 @@ class LLMServer:
         # Initialize the cache.
         self._run_workers("init_cache_engine", cache_config=self.cache_config)

+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(*server_configs, distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
+
     def add_request(
         self,
         request_id: str,
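
Not part of the diff: from_server_args folds the old three-step setup (build the configs, initialize the cluster, construct the server) into one call. A short usage sketch, assuming the cacheflow package from this commit is installed:

from cacheflow import LLMServer, ServerArgs

# Builds the configs, initializes the (possibly Ray) cluster, and
# constructs the server in a single call.
server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))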


@@ -0,0 +1,23 @@
+from cacheflow import LLM, SamplingParams
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# Create an LLM.
+llm = LLM(model="facebook/opt-125m")
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@@ -1,13 +1,13 @@
 import argparse
 import uuid

-from cacheflow import (add_server_arguments, initialize_server_from_args,
-                       SamplingParams)
+from cacheflow import ServerArgs, LLMServer, SamplingParams


 def main(args: argparse.Namespace):
-    # Initialize the server.
-    server = initialize_server_from_args(args)
+    # Parse the CLI argument and initialize the server.
+    server_args = ServerArgs.from_cli_args(args)
+    server = LLMServer.from_server_args(server_args)

     # Test the following prompts.
     test_prompts = [
@@ -39,6 +39,6 @@ def main(args: argparse.Namespace):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)