Introduce LLM class for offline inference (#115)
parent f746ced08d
commit 655a5e48df
@@ -1,19 +1,15 @@
+from cacheflow.entrypoints.llm import LLM
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments,
-    create_server_configs_from_args,
-    initialize_server_from_args,
-)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster

 __all__ = [
-    "RequestOutput",
+    "LLM",
     "SamplingParams",
+    "RequestOutput",
     "LLMServer",
-    "add_server_arguments",
-    "create_server_configs_from_args",
-    "initialize_server_from_args",
+    "ServerArgs",
     "initialize_cluster",
 ]
@@ -3,6 +3,8 @@ from typing import Optional
 import torch
 from transformers import AutoConfig, PretrainedConfig

+_GiB = 1 << 30
+

 class ModelConfig:
@@ -70,7 +72,7 @@ class CacheConfig:
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space = swap_space
+        self.swap_space_bytes = swap_space * _GiB

         # Will be set after profiling.
         self.num_gpu_blocks = None
@@ -138,6 +140,8 @@ def _get_and_verify_dtype(
         else:
             torch_dtype = config_dtype
     else:
+        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+            raise ValueError(f"Unknown dtype: {dtype}")
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

     # Verify the dtype.
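For illustration (not part of the diff), the renamed attribute means callers now read the swap size in bytes from the config. A minimal sketch, assuming only the CacheConfig constructor shown above, with example values:

from cacheflow.config import CacheConfig

# swap_space is still given in GiB; the config now exposes the byte count
# under the new name. The 0.95 / 4 values here are just example inputs.
cache_config = CacheConfig(block_size=16, gpu_memory_utilization=0.95,
                           swap_space=4)
assert cache_config.swap_space_bytes == 4 * (1 << 30)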
@@ -12,8 +12,7 @@ import uvicorn
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments, create_server_configs_from_args)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster

@@ -116,10 +115,10 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=10002)
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_configs = create_server_configs_from_args(args)
+    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
     parallel_config = server_configs[2]
     distributed_init_method, stage_devices = initialize_cluster(parallel_config)
cacheflow/entrypoints/llm.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+from typing import List, Optional
+
+from tqdm import tqdm
+
+from cacheflow.outputs import RequestOutput
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.llm_server import LLMServer
+from cacheflow.utils import Counter
+
+
+class LLM:
+
+    def __init__(
+        self,
+        model: str,
+        tensor_parallel_size: int = 1,
+        dtype: str = "default",
+        seed: int = 0,
+        **kwargs,
+    ) -> None:
+        if "disable_log_stats" not in kwargs:
+            kwargs["disable_log_stats"] = True
+        server_args = ServerArgs(
+            model=model,
+            tensor_parallel_size=tensor_parallel_size,
+            dtype=dtype,
+            seed=seed,
+            **kwargs,
+        )
+        self.llm_server = LLMServer.from_server_args(server_args)
+        self.request_counter = Counter()
+
+    def generate(
+        self,
+        prompts: List[str],
+        sampling_params: Optional[SamplingParams] = None,
+        use_tqdm: bool = True,
+    ) -> List[RequestOutput]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        # Initialize tqdm.
+        if use_tqdm:
+            pbar = tqdm(total=len(prompts), desc="Processed prompts")
+
+        # Add requests to the server.
+        for prompt in prompts:
+            request_id = str(next(self.request_counter))
+            self.llm_server.add_request(request_id, prompt, sampling_params)
+
+        # Run the server.
+        outputs: List[RequestOutput] = []
+        while self.llm_server.has_unfinished_requests():
+            step_outputs = self.llm_server.step()
+            for output in step_outputs:
+                if output.done:
+                    outputs.append(output)
+                    if use_tqdm:
+                        pbar.update(1)
+        if use_tqdm:
+            pbar.close()
+        return outputs
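A usage sketch (not part of the diff): extra keyword arguments to LLM(...) are forwarded to ServerArgs, so any ServerArgs field from the arg_utils hunk below can be overridden at construction time. The model name and numeric values are only examples.

from cacheflow import LLM, SamplingParams

# gpu_memory_utilization and max_num_seqs are ServerArgs fields forwarded
# through **kwargs; the numbers are illustrative, not recommended settings.
llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.90,
    max_num_seqs=128,
)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)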
@@ -35,7 +35,7 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool = False,
+        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -43,8 +43,8 @@ class RequestOutput:
         self.outputs = outputs
         self.done = done

-    @staticmethod
-    def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
+    @classmethod
+    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
@@ -70,8 +70,8 @@ class RequestOutput:
         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
-                             outputs, seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
+                   seq_group.is_finished())

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
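One practical effect of moving from_seq_group from a @staticmethod to a @classmethod is that subclasses now get instances of their own type back. A brief sketch; LoggingRequestOutput is a hypothetical subclass, not part of this commit:

from cacheflow.outputs import RequestOutput


class LoggingRequestOutput(RequestOutput):
    """Hypothetical subclass used only to illustrate the cls(...) return."""

    def __repr__(self) -> str:
        return f"LoggingRequestOutput(request_id={self.request_id})"

# LoggingRequestOutput.from_seq_group(seq_group) now constructs a
# LoggingRequestOutput, whereas the old staticmethod always returned a
# plain RequestOutput.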
@@ -1,74 +1,117 @@
 import argparse
-from typing import Tuple
+import dataclasses
+from dataclasses import dataclass
+from typing import Optional, Tuple

 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
-from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-
-_GiB = 1 << 30


-def add_server_arguments(parser: argparse.ArgumentParser):
-    """Shared arguments for CacheFlow servers."""
+@dataclass
+class ServerArgs:
+    model: str
+    download_dir: Optional[str] = None
+    use_np_weights: bool = False
+    use_dummy_weights: bool = False
+    dtype: str = "default"
+    seed: int = 0
+    use_ray: bool = False
+    pipeline_parallel_size: int = 1
+    tensor_parallel_size: int = 1
+    block_size: int = 16
+    swap_space: int = 4  # GiB
+    gpu_memory_utilization: float = 0.95
+    max_num_batched_tokens: int = 2560
+    max_num_seqs: int = 256
+    disable_log_stats: bool = False
+
+    def __post_init__(self):
+        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
+
+    @staticmethod
+    def add_cli_args(
+        parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        return _add_server_arguments(parser)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        # Set the attributes from the parsed arguments.
+        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return server_args
+
+    def create_server_configs(
+        self,
+    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
+        # Initialize the configs.
+        model_config = ModelConfig(
+            self.model, self.download_dir, self.use_np_weights,
+            self.use_dummy_weights, self.dtype, self.seed)
+        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
+                                   self.swap_space)
+        parallel_config = ParallelConfig(self.pipeline_parallel_size,
+                                         self.tensor_parallel_size,
+                                         self.use_ray)
+        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
+                                           self.max_num_seqs)
+        return model_config, cache_config, parallel_config, scheduler_config
+
+
+def _add_server_arguments(
+    parser: argparse.ArgumentParser,
+)-> argparse.ArgumentParser:
+    """Shared CLI arguments for CacheFlow servers."""
     # Model arguments
-    parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
-    parser.add_argument('--download-dir', type=str, default=None,
+    parser.add_argument('--model', type=str, default='facebook/opt-125m',
+                        help='name or path of the huggingface model to use')
+    parser.add_argument('--download-dir', type=str,
+                        default=ServerArgs.download_dir,
                         help='directory to download and load the weights, '
                              'default to the default cache dir of huggingface')
     parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster loading')
-    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
+                        help='save a numpy copy of model weights for faster '
+                             'loading. This can increase the disk usage by up '
+                             'to 2x.')
+    parser.add_argument('--use-dummy-weights', action='store_true',
+                        help='use dummy values for model weights')
     # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
+    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+                        choices=['default', 'half', 'bfloat16'],
                         help=('data type for model weights and activations. '
                               'The "default" option will use FP16 precision '
                               'for FP32 and FP16 models, and BF16 precision '
                               'for BF16 models.'))
     # Parallel arguments
-    parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
+    parser.add_argument('--use-ray', action='store_true',
+                        help='use Ray for distributed serving, will be '
+                             'automatically set when using more than 1 GPU')
+    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                        default=ServerArgs.pipeline_parallel_size,
+                        help='number of pipeline stages')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                        default=ServerArgs.tensor_parallel_size,
+                        help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
+    parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
+                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                        help='token block size')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=0, help='random seed')
-    parser.add_argument('--swap-space', type=int, default=4, help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int, default=256, help='maximum number of sequences per iteration')
-    parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
+    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+                        help='random seed')
+    parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
+                        help='CPU swap space size (GiB) per GPU')
+    parser.add_argument('--gpu-memory-utilization', type=float,
+                        default=ServerArgs.gpu_memory_utilization,
+                        help='the percentage of GPU memory to be used for the '
+                             'model executor')
+    parser.add_argument('--max-num-batched-tokens', type=int,
+                        default=ServerArgs.max_num_batched_tokens,
+                        help='maximum number of batched tokens per iteration')
+    parser.add_argument('--max-num-seqs', type=int,
+                        default=ServerArgs.max_num_seqs,
+                        help='maximum number of sequences per iteration')
+    parser.add_argument('--disable-log-stats', action='store_true',
+                        help='disable logging statistics')
     return parser
-
-
-def create_server_configs_from_args(
-    args: argparse.Namespace,
-) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
-    # Post-process the parsed arguments.
-    args.swap_space = args.swap_space * _GiB
-    args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)
-
-    # Initialize the configs.
-    model_config = ModelConfig(
-        args.model, args.download_dir, args.use_np_weights,
-        args.use_dummy_weights, args.dtype, args.seed)
-    cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
-                               args.swap_space)
-    parallel_config = ParallelConfig(args.pipeline_parallel_size,
-                                     args.tensor_parallel_size, args.use_ray)
-    scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
-                                       args.max_num_seqs)
-    return model_config, cache_config, parallel_config, scheduler_config
-
-
-def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
-    server_configs = create_server_configs_from_args(args)
-    parallel_config = server_configs[2]
-
-    # Initialize the cluster.
-    distributed_init_method, devices = initialize_cluster(parallel_config)
-
-    # Create the LLM server.
-    server = LLMServer(*server_configs, distributed_init_method, devices,
-                       log_stats=not args.disable_log_stats)
-    return server
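For illustration, the two intended ways to consume ServerArgs after this refactor, using only names introduced in the hunk above; the flag values are examples:

import argparse

from cacheflow.server.arg_utils import ServerArgs

# Programmatic path: build the dataclass directly and derive the configs.
server_args = ServerArgs(model="facebook/opt-125m", tensor_parallel_size=1)
model_config, cache_config, parallel_config, scheduler_config = (
    server_args.create_server_configs())

# CLI path: the dataclass registers its own arguments and reads them back.
parser = ServerArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m",
                          "--max-num-seqs", "64"])
server_args = ServerArgs.from_cli_args(args)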
@@ -12,6 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.ray_utils import initialize_cluster
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
@@ -30,7 +32,7 @@ class LLMServer:
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
         stage_devices: List[List[Any]],
-        log_stats: bool = True,
+        log_stats: bool,
     ) -> None:
         logger.info(
             "Initializing an LLM server with config: "
@@ -90,7 +92,7 @@ class LLMServer:
             get_all_outputs=True,
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
         )

         # Since we use a shared centralized controller, we take the minimum
@@ -107,6 +109,18 @@ class LLMServer:
         # Initialize the cache.
         self._run_workers("init_cache_engine", cache_config=self.cache_config)

+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(*server_configs, distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
+
     def add_request(
         self,
         request_id: str,
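A sketch of driving LLMServer directly through the new classmethod, mirroring what LLM.generate does internally; the prompt and request id are arbitrary examples:

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer

server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))
server.add_request("0", "The capital of France is", SamplingParams())

# Step the server until all requests finish, keeping only completed outputs.
finished = []
while server.has_unfinished_requests():
    for request_output in server.step():
        if request_output.done:
            finished.append(request_output)
print(finished[0].outputs[0].text)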
examples/offline_inference.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from cacheflow import LLM, SamplingParams
+
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# Create an LLM.
+llm = LLM(model="facebook/opt-125m")
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
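Since RequestOutput.outputs is a list of CompletionOutput objects, a variant of the example that samples several candidates per prompt might look as follows; passing n to SamplingParams is assumed here from the sampling_params.n reference in the outputs.py hunk, and n=3 is just an example:

from cacheflow import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
# n (assumed constructor argument) controls how many sequences are sampled
# per prompt, matching the top-n logic in RequestOutput.from_seq_group.
sampling_params = SamplingParams(n=3, temperature=0.8, top_p=0.95)
for output in llm.generate(["The future of AI is"], sampling_params):
    for i, completion in enumerate(output.outputs):
        print(f"Candidate {i}: {completion.text!r}")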
@@ -1,13 +1,13 @@
 import argparse
 import uuid

-from cacheflow import (add_server_arguments, initialize_server_from_args,
-                       SamplingParams)
+from cacheflow import ServerArgs, LLMServer, SamplingParams


 def main(args: argparse.Namespace):
-    # Initialize the server.
-    server = initialize_server_from_args(args)
+    # Parse the CLI argument and initialize the server.
+    server_args = ServerArgs.from_cli_args(args)
+    server = LLMServer.from_server_args(server_args)

     # Test the following prompts.
     test_prompts = [
@@ -39,6 +39,6 @@ def main(args: argparse.Namespace):

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)