Refactor system architecture (#109)
This commit is contained in:
parent 7297fa6f7c
commit c3442c1f6f

README.md (13 changes)
@@ -10,13 +10,17 @@ pip install -e . # This may take several minutes.
 ## Test simple server
 
 ```bash
+# Single-GPU inference.
+python examples/simple_server.py # --model <your_model>
+
+# Multi-GPU inference (e.g., 2 GPUs).
 ray start --head
-python simple_server.py
+python examples/simple_server.py -tp 2 # --model <your_model>
 ```
 
 The detailed arguments for `simple_server.py` can be found by:
 ```bash
-python simple_server.py --help
+python examples/simple_server.py --help
 ```
 
 ## FastAPI server
@@ -24,12 +28,12 @@ python simple_server.py --help
 To start the server:
 ```bash
 ray start --head
-python -m cacheflow.http_frontend.fastapi_frontend
+python -m cacheflow.entrypoints.fastapi_server # --model <your_model>
 ```
 
 To test the server:
 ```bash
-python -m cacheflow.http_frontend.test_cli_client
+python test_cli_client.py
 ```
 
 ## Gradio web server
@@ -55,7 +59,6 @@ Since LLaMA weight is not fully public, we cannot directly download the LLaMA we
 python src/transformers/models/llama/convert_llama_weights_to_hf.py \
     --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path/llama-7b
 ```
-Please make sure that `llama` is included in the output directory name.
 2. For all the commands above, specify the model with `--model /output/path/llama-7b` to load the model. For example:
 ```bash
 python simple_server.py --model /output/path/llama-7b
cacheflow/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import (
    add_server_arguments,
    create_server_configs_from_args,
    initialize_server_from_args,
)
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster

__all__ = [
    "RequestOutput",
    "SamplingParams",
    "LLMServer",
    "add_server_arguments",
    "create_server_configs_from_args",
    "initialize_server_from_args",
    "initialize_cluster",
]
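The new top-level exports make the serving engine importable directly from `cacheflow`. The sketch below is not part of the commit; it assumes `initialize_server_from_args` returns an `LLMServer`, reuses the `add_request`/`step` signatures that the FastAPI entrypoint later in this diff calls on that class, and assumes `SamplingParams` accepts `n`/`max_tokens` as keyword arguments (as `SamplingParams(**request_dict)` in the entrypoint implies).

```python
# Hypothetical usage sketch based on the exports above. The helper names come
# from __all__; the request/step signatures mirror the FastAPI entrypoint.
import argparse

from cacheflow import (SamplingParams, add_server_arguments,
                       initialize_server_from_args)

parser = argparse.ArgumentParser()
parser = add_server_arguments(parser)
args = parser.parse_args()

# Assumed to build the config objects and return a ready LLMServer.
server = initialize_server_from_args(args)
server.add_request("request-0", "Hello, my name is",
                   SamplingParams(n=1, max_tokens=32))

finished = False
while not finished:
    # step() returns RequestOutput objects for sequences that made progress.
    for request_output in server.step():
        if request_output.request_id == "request-0" and request_output.done:
            print(request_output.outputs[0].text)
            finished = True
```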
cacheflow/config.py (new file, 165 lines)
@@ -0,0 +1,165 @@
from typing import Optional

import torch
from transformers import AutoConfig, PretrainedConfig


class ModelConfig:

    def __init__(
        self,
        model: str,
        download_dir: Optional[str],
        use_np_weights: bool,
        use_dummy_weights: bool,
        dtype: str,
        seed: int,
    ) -> None:
        self.model = model
        self.download_dir = download_dir
        self.use_np_weights = use_np_weights
        self.use_dummy_weights = use_dummy_weights
        self.seed = seed

        self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_hidden_size(self) -> int:
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
        total_num_attention_heads = self.hf_config.num_attention_heads
        return total_num_attention_heads // parallel_config.tensor_parallel_size

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space = swap_space

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None


class ParallelConfig:

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        use_ray: bool,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.use_ray = use_ray

        self.world_size = pipeline_parallel_size * tensor_parallel_size
        if self.world_size > 1:
            self.use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")


class SchedulerConfig:

    def __init__(
        self,
        max_num_batched_tokens: int,
        max_num_seqs: int,
    ) -> None:
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: str,
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    dtype = dtype.lower()
    if dtype == "default":
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    else:
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            pass
        else:
            # Casting between float16 and bfloat16 is not allowed.
            raise ValueError(
                f"Cannot use {torch_dtype} for {config_dtype} model.")

    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}.")
    return torch_dtype
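To make the relationships concrete, here is a hedged sketch (not from the commit) of building these configs by hand. `facebook/opt-125m` is only an example model name; note that `dtype="default"` resolves to `torch.float16` for a `float32` checkpoint via `_get_and_verify_dtype`, and that `verify_with_parallel_config` is where an invalid tensor-parallel degree is rejected.

```python
# Hypothetical sketch: constructing the new config objects directly.
from cacheflow.config import ModelConfig, ParallelConfig, SchedulerConfig

model_config = ModelConfig(
    model="facebook/opt-125m",  # example; the HF config is fetched on init
    download_dir=None,
    use_np_weights=False,
    use_dummy_weights=True,
    dtype="default",            # float32 checkpoint -> torch.float16 here
    seed=0,
)
parallel_config = ParallelConfig(
    pipeline_parallel_size=1,
    tensor_parallel_size=2,     # OPT-125m has 12 attention heads; 12 % 2 == 0
    use_ray=False,              # forced to True anyway since world_size > 1
)
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=2560,
    max_num_seqs=256,
)

# Raises ValueError when num_attention_heads % tensor_parallel_size != 0.
model_config.verify_with_parallel_config(parallel_config)
```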
@@ -2,10 +2,10 @@ import enum
 import time
 from typing import Dict, List, Optional, Tuple
 
+from cacheflow.config import CacheConfig, SchedulerConfig
 from cacheflow.core.block_manager import BlockSpaceManager
 from cacheflow.core.policy import PolicyFactory
 from cacheflow.logger import init_logger
-from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
                                 SequenceGroupMetadata, SequenceOutputs,
                                 SequenceStatus)
@@ -28,43 +28,53 @@ class PreemptionMode(enum.Enum):
     RECOMPUTE = enum.auto()
 
 
+class SchedulerOutputs:
+
+    def __init__(
+        self,
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> None:
+        self.blocks_to_swap_in = blocks_to_swap_in
+        self.blocks_to_swap_out = blocks_to_swap_out
+        self.blocks_to_copy = blocks_to_copy
+        # Swap in and swap out should never happen at the same time.
+        assert not (blocks_to_swap_in and blocks_to_swap_out)
+
+    def is_empty(self) -> bool:
+        return (not self.blocks_to_swap_in
+                and not self.blocks_to_swap_out
+                and not self.blocks_to_copy)
+
+
 class Scheduler:
 
     def __init__(
         self,
-        controllers: List,
-        block_size: int,
-        num_gpu_blocks: int,
-        num_cpu_blocks: int,
-        max_num_batched_tokens: int,
-        max_num_sequences: int,
+        scheduler_config: SchedulerConfig,
+        cache_config: CacheConfig,
         log_stats: bool,
     ) -> None:
-        self.controllers = controllers
-        self.block_size = block_size
-        self.num_gpu_blocks = num_gpu_blocks
-        self.num_cpu_blocks = num_cpu_blocks
-        self.max_num_batched_tokens = max_num_batched_tokens
-        self.max_num_sequences = max_num_sequences
+        self.scheduler_config = scheduler_config
+        self.cache_config = cache_config
         self.log_stats = log_stats
 
         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name='fcfs')
         # Create the block space manager.
         self.block_manager = BlockSpaceManager(
-            block_size=block_size,
-            num_gpu_blocks=num_gpu_blocks,
-            num_cpu_blocks=num_cpu_blocks,
+            block_size=self.cache_config.block_size,
+            num_gpu_blocks=self.cache_config.num_gpu_blocks,
+            num_cpu_blocks=self.cache_config.num_cpu_blocks,
         )
 
         # Sequence groups in the WAITING state.
         self.waiting: List[SequenceGroup] = []
         # Sequence groups in the RUNNING state.
         self.running: List[SequenceGroup] = []
-        # Mapping: group_id -> num_steps.
-        self.num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> sampling params.
-        self.sampling_params: Dict[int, SamplingParams] = {}
+        # Mapping: request_id -> num_steps.
+        self.num_steps: Dict[str, int] = {}
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []
 
@@ -72,18 +82,15 @@ class Scheduler:
         # List[timestamp, num_tokens]
         self.num_input_tokens: List[Tuple[float, int]] = []
 
-    def add_sequence_groups(
-        self,
-        seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
-    ) -> None:
-        # Add sequence groups to the waiting queue.
-        for seq_group, sampling_params in seq_groups:
-            self.waiting.append(seq_group)
-            self.sampling_params[seq_group.group_id] = sampling_params
+    def add_seq_group(self, seq_group: SequenceGroup) -> None:
+        # Add sequence groups to the waiting queue.
+        assert seq_group.request_id not in self.num_steps
+        self.waiting.append(seq_group)
 
-    def _schedule(
-        self,
-    ) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
+    def has_unfinished_seqs(self) -> bool:
+        return self.waiting or self.running or self.swapped
+
+    def _schedule(self) -> Tuple[SchedulerOutputs, List[int]]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
@@ -136,8 +143,9 @@
 
             # The total number of sequences in the RUNNING state should not
             # exceed the maximum number of sequences.
-            num_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
-            if len(self.running) + num_seqs > self.max_num_sequences:
+            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
+            num_curr_seqs = len(self.running)
+            if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
                 break
 
             seq_group = self.swapped.pop(0)
@@ -151,7 +159,7 @@
             )
 
         # Join waiting sequences if possible.
-        prompt_group_ids: List[int] = []
+        prompt_group_ids: List[str] = []
         # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
         # prioritized over the sequence groups in the WAITING state.
         # This is because we want to bound the amount of CPU memory taken by
@@ -172,25 +180,31 @@
             # If the number of batched tokens exceeds the limit, stop.
             num_prompt_tokens = seq_group.seqs[0].get_len()
             if (num_batched_tokens + num_prompt_tokens
-                    > self.max_num_batched_tokens):
+                    > self.scheduler_config.max_num_batched_tokens):
                 break
 
             # The total number of sequences in the RUNNING state should not
             # exceed the maximum number of sequences.
-            num_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
-            if len(self.running) + num_seqs > self.max_num_sequences:
+            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
+            num_curr_seqs = len(self.running)
+            if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
                 break
 
             seq_group = self.waiting.pop(0)
            self._allocate(seq_group)
             self.running.append(seq_group)
             num_batched_tokens += num_prompt_tokens
-            prompt_group_ids.append(seq_group.group_id)
+            prompt_group_ids.append(seq_group.request_id)
 
+        scheduler_outputs = SchedulerOutputs(
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+        )
         if not self.log_stats:
-            return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
-                    prompt_group_ids)
+            return scheduler_outputs, prompt_group_ids
 
+        # TODO(woosuk): Move the below code to server.
         now = time.time()
         if num_batched_tokens > 0:
             self.num_input_tokens.append((now, num_batched_tokens))
@@ -208,13 +222,16 @@
         else:
             avg_throughput = 0.0
 
+        total_num_gpu_blocks = self.cache_config.num_gpu_blocks
         num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
-        num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks
-        gpu_cache_usage = num_used_gpu_blocks / self.num_gpu_blocks
-        if self.num_cpu_blocks > 0:
+        num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
+        gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
+
+        total_num_cpu_blocks = self.cache_config.num_cpu_blocks
+        if total_num_cpu_blocks > 0:
             num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
-            num_used_cpu_blocks = self.num_cpu_blocks - num_free_cpu_blocks
-            cpu_cache_usage = num_used_cpu_blocks / self.num_cpu_blocks
+            num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
+            cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
         else:
             cpu_cache_usage = 0.0
 
@@ -225,27 +242,18 @@
                     f"Pending: {len(self.waiting)} reqs, "
                     f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
+        return scheduler_outputs, prompt_group_ids
 
-        return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
-                prompt_group_ids)
-
-    def step(self) -> List[SequenceGroup]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        scheduler_output = self._schedule()
-        blocks_to_swap_in = scheduler_output[0]
-        blocks_to_swap_out = scheduler_output[1]
-        blocks_to_copy = scheduler_output[2]
-        prompt_group_ids = scheduler_output[3]
+        scheduler_outputs, prompt_group_ids = self._schedule()
 
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        updated_seq_groups: List[SequenceGroup] = self.running.copy()
-
         for seq_group in self.running:
-            group_id = seq_group.group_id
-            is_prompt = group_id in prompt_group_ids
+            is_prompt = seq_group.request_id in prompt_group_ids
 
             seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
@@ -255,36 +263,24 @@
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
 
             seq_group_metadata = SequenceGroupMetadata(
-                group_id=group_id,
+                request_id=seq_group.request_id,
                 is_prompt=is_prompt,
                 seq_data=seq_data,
-                sampling_params=self.sampling_params[group_id],
+                sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
+        return seq_group_metadata_list, scheduler_outputs
 
-        # Execute the first stage of the pipeline.
-        if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
-            # Swap in and swap out should never happen at the same time.
-            assert not (blocks_to_swap_in and blocks_to_swap_out)
-            self.controllers[0].execute_stage(
-                seq_group_metadata_list,
-                blocks_to_swap_in=blocks_to_swap_in,
-                blocks_to_swap_out=blocks_to_swap_out,
-                blocks_to_copy=blocks_to_copy,
-            )
-
-        return updated_seq_groups
-
-    def post_step(
+    def update(
         self,
         seq_outputs: Dict[int, SequenceOutputs],
-    ) -> None:
+    ) -> List[SequenceGroup]:
         # Update the running sequences and free blocks.
         for seq_group in self.running:
-            group_id = seq_group.group_id
-            self.num_steps[group_id] += 1
-            stop_token_ids = self.sampling_params[group_id].stop_token_ids
+            request_id = seq_group.request_id
+            self.num_steps[request_id] += 1
+            stop_token_ids = seq_group.sampling_params.stop_token_ids
 
             # Process beam search results before processing the next tokens.
             for seq in seq_group.seqs:
@@ -316,12 +312,13 @@
                     continue
 
                 # Check if the sequence has reached the maximum number of steps.
-                max_num_steps = self.sampling_params[group_id].max_tokens
-                if self.num_steps[group_id] == max_num_steps:
+                max_num_steps = seq_group.sampling_params.max_tokens
+                if self.num_steps[request_id] == max_num_steps:
                     self._free_seq(seq)
                     continue
 
         # Update the running sequences.
+        updated = self.running.copy()
         running: List[SequenceGroup] = []
         for seq_group in self.running:
             if seq_group.is_finished():
@@ -329,13 +326,14 @@
             else:
                 running.append(seq_group)
         self.running = running
+        return updated
 
     def _allocate(self, seq_group: SequenceGroup) -> None:
         self.block_manager.allocate(seq_group)
         for seq in seq_group.seqs:
            seq.status = SequenceStatus.RUNNING
-        if seq_group.group_id not in self.num_steps:
-            self.num_steps[seq_group.group_id] = 0
+        if seq_group.request_id not in self.num_steps:
+            self.num_steps[seq_group.request_id] = 0
 
     def _append_slot(
         self,
@@ -410,9 +408,7 @@
         self.block_manager.free(seq)
 
     def _free_seq_group(self, seq_group: SequenceGroup) -> None:
-        group_id = seq_group.group_id
-        del self.num_steps[group_id]
-        del self.sampling_params[group_id]
+        del self.num_steps[seq_group.request_id]
 
     def _swap_in(
         self,
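With this refactor the scheduler no longer calls into the workers itself: `schedule()` only returns the per-group metadata plus a `SchedulerOutputs` describing the swap/copy operations, and `update()` folds the sampled outputs back in. A hedged sketch of the driving loop that now lives outside the scheduler (the `execute_model` call is a placeholder for whatever the new `LLMServer` does with its workers; that code is not part of this hunk):

```python
# Hypothetical driver loop around the refactored Scheduler API.
seq_group_metadata_list, scheduler_outputs = scheduler.schedule()

if seq_group_metadata_list or not scheduler_outputs.is_empty():
    seq_outputs = execute_model(  # placeholder for the worker-side call
        seq_group_metadata_list,
        blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
        blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
        blocks_to_copy=scheduler_outputs.blocks_to_copy,
    )
    # update() returns the sequence groups that were running this step.
    updated_seq_groups = scheduler.update(seq_outputs)
```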
@@ -1,302 +0,0 @@
import argparse
import random
from typing import List, Optional, Tuple

try:
    import ray
except ImportError:
    ray = None
import numpy as np
import torch

from cacheflow.core.scheduler import Scheduler
from cacheflow.frontend.simple_frontend import SimpleFrontend
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroup
from cacheflow.worker.controller import Controller, DeviceID

logger = init_logger(__name__)


class Server:

    def __init__(
        self,
        model: str,
        cache_dir: Optional[str],
        use_dummy_weights: bool,
        use_np_cache: bool,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        gpu_memory_utilization: float,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
        use_ray: bool,
        log_stats: bool,
    ):
        logger.info(
            "Initializing a server with config: "
            f"model={model!r}, "
            f"dtype={dtype}, "
            f"use_dummy_weights={use_dummy_weights}, "
            f"cache_dir={cache_dir!r}, "
            f"use_np_cache={use_np_cache}, "
            f"tensor_parallel_size={tensor_parallel_size}, "
            f"seed={seed})"
        )
        self.num_nodes = num_nodes
        self.num_devices_per_node = num_devices_per_node
        self.world_size = pipeline_parallel_size * tensor_parallel_size

        if not use_ray:
            assert self.world_size == 1, (
                "Only support single GPU without Ray.")

        # Create a controller for each pipeline stage.
        self.controllers: List[Controller] = []
        for i in range(pipeline_parallel_size):
            controller = Controller(
                stage_id=i,
                stage_devices=all_stage_devices[i],
                world_size=self.world_size,
                pipeline_parallel_size=pipeline_parallel_size,
                tensor_parallel_size=tensor_parallel_size,
                distributed_init_method=distributed_init_method,
                model_name=model,
                dtype=dtype,
                seed=seed,
                cache_dir=cache_dir,
                use_dummy_weights=use_dummy_weights,
                use_np_cache=use_np_cache,
                max_num_batched_tokens=max_num_batched_tokens,
                max_num_sequences=max_num_sequences,
                use_ray=use_ray,
            )
            self.controllers.append(controller)

        # Initialize cache engine.
        all_worker_num_available_blocks = []
        for controller in self.controllers:
            all_worker_num_available_blocks.extend(
                controller.get_num_available_blocks(
                    block_size, swap_space, gpu_memory_utilization)
            )
        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        self.num_gpu_blocks = np.min([b[0] for b in all_worker_num_available_blocks])
        self.num_cpu_blocks = np.min([b[1] for b in all_worker_num_available_blocks])
        logger.info(f'# GPU blocks: {self.num_gpu_blocks}, '
                    f'# CPU blocks: {self.num_cpu_blocks}')
        for controller in self.controllers:
            controller.init_cache_engine(block_size, self.num_gpu_blocks,
                                         self.num_cpu_blocks)

        # Create a scheduler.
        self.scheduler = Scheduler(
            controllers=self.controllers,
            block_size=block_size,
            num_gpu_blocks=self.num_gpu_blocks,
            num_cpu_blocks=self.num_cpu_blocks,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_sequences=max_num_sequences,
            log_stats=log_stats,
        )
        # Connect the controllers.
        for i in range(len(self.controllers) - 1):
            self.controllers[i].set_next(self.controllers[i + 1])
        self.controllers[-1].set_next(self.scheduler)

    def add_sequence_groups(
        self,
        sequence_groups: List[Tuple[SequenceGroup, SamplingParams]]
    ):
        self.scheduler.add_sequence_groups(sequence_groups)

    def step(self):
        return self.scheduler.step()

    def has_unfinished_requests(self):
        return (self.scheduler.waiting or self.scheduler.running or
                self.scheduler.swapped)


def initialize_cluster(
    use_ray: bool = False,
    address: Optional[str] = None,
    pipeline_parallel_size: int = 1,
    tensor_parallel_size: int = 1,
) -> Tuple[int, int, str, List[List[DeviceID]]]:
    # Initialize cluster locally.
    if not use_ray:
        assert pipeline_parallel_size * tensor_parallel_size == 1, (
            "Only support single GPU without Ray.")
        num_nodes = 1
        num_devices_per_node = torch.cuda.device_count()
        port = random.randint(10000, 20000)
        # We need to setup the distributed init method to make sure
        # the distributed megatron code (e.g., get world size) works correctly.
        distributed_init_method = f"tcp://localhost:{port}"
        all_stage_devices = [[(0, None, 0)]]
        return (num_nodes, num_devices_per_node, distributed_init_method,
                all_stage_devices)

    assert ray is not None, (
        "Ray is not installed. Please install Ray to use distributed "
        "serving.")

    # Connect to a ray cluster.
    ray.init(address=address)

    # Assume we have a uniform cluster that each node has the same number of
    # GPUs for now.
    valid_node_resources = []
    num_devices_per_node = None
    for node in ray.nodes():
        if (not node['Alive']) or node['Resources']['GPU'] <= 0:
            continue
        if num_devices_per_node is None:
            num_devices_per_node = node['Resources']['GPU']
        else:
            assert num_devices_per_node == node['Resources']['GPU'], (
                "The number of GPUs per node is not uniform.")
        for key in node['Resources']:
            if key.startswith('node:'):
                valid_node_resources.append(key)

    num_nodes = len(valid_node_resources)

    assert (pipeline_parallel_size * tensor_parallel_size
            <= num_nodes * num_devices_per_node), (
        "The number of required GPUs exceeds the total number of "
        "available GPUs.")
    if tensor_parallel_size >= num_devices_per_node:
        assert tensor_parallel_size % num_devices_per_node == 0, (
            "The number of tensor parallelism is not divisible by the "
            "number of GPUs per node.")
    else:
        assert num_devices_per_node % tensor_parallel_size == 0, (
            "The number of GPUs per node is not divisible by the number "
            "of tensor parallelism.")

    # Assign GPUs to pipeline stages.
    rank = 0
    current_node_id = 0
    current_device_id = 0
    distributed_init_method = None
    all_stage_devices = []

    for i in range(pipeline_parallel_size):
        stage_devices = []
        for j in range(tensor_parallel_size):
            node_resource = valid_node_resources[current_node_id]
            stage_devices.append((rank, node_resource, current_device_id))
            if distributed_init_method is None:
                ip = node_resource.split("node:")[-1]
                port = random.randint(10000, 20000)
                distributed_init_method = f"tcp://{ip}:{port}"
            rank += 1
            current_device_id += 1
            if current_device_id >= num_devices_per_node:
                current_node_id += 1
                current_device_id = 0
        all_stage_devices.append(stage_devices)

    return (num_nodes, num_devices_per_node, distributed_init_method,
            all_stage_devices)


_GiB = 1 << 30


def add_server_arguments(parser: argparse.ArgumentParser):
    """Shared arguments for CacheFlow servers."""
    # Model arguments
    parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
    parser.add_argument('--cache-dir', type=str, default=None,
                        help='cache dir to download and load the weights, '
                             'default to the default cache dir of huggingface')
    parser.add_argument('--use-np-cache', action='store_true',
                        help='save a numpy copy of model weights for faster loading')
    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
    # TODO(woosuk): Support FP32 for debugging.
    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
                        help=('data type for model weights and activations. '
                              'The "default" option will use FP16 precision '
                              'for FP32 and FP16 models, and BF16 precision '
                              'for BF16 models.'))
    # Parallel arguments
    parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
    # KV cache arguments
    parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
    parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
    parser.add_argument('--log-stats', action='store_true', help='log system statistics')
    return parser


def process_server_arguments(args: argparse.Namespace):
    """Post process the parsed arguments."""
    if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
        args.use_ray = True
    args.swap_space = args.swap_space * _GiB
    args.max_num_sequences = min(args.max_num_sequences, args.max_num_batched_tokens)
    return args


def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_cluster(
            use_ray=args.use_ray,
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    # Create a server.
    server = Server(
        model=args.model,
        cache_dir=args.cache_dir,
        use_dummy_weights=args.use_dummy_weights,
        use_np_cache=args.use_np_cache,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        gpu_memory_utilization=args.gpu_memory_utilization,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_sequences=args.max_num_sequences,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        use_ray=args.use_ray,
        log_stats=args.log_stats,
    )

    # Create a frontend.
    frontend = SimpleFrontend(
        model_name=args.model,
        block_size=args.block_size,
    )
    return server, frontend
cacheflow/entrypoints/fastapi_server.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import argparse
import asyncio
import json
import time
from typing import Any, Dict
import uuid

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import ray
import uvicorn

from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import (
    add_server_arguments, create_server_configs_from_args)
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
app = FastAPI()


class FastAPIServer:

    def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
        if server_use_ray:
            remote_server_class = ray.remote(num_cpus=0)(LLMServer)
        else:
            remote_server_class = ray.remote(num_gpus=1)(LLMServer)
        self.server = remote_server_class.remote(*args, **kwargs)

        # Request id -> request output.
        self.request_outputs: Dict[str, RequestOutput] = {}
        # Request id -> event to notify that there is new output.
        self.request_events: Dict[str, asyncio.Event] = {}
        self.is_server_running = False

    async def server_step(self):
        self.is_server_running = True
        request_outputs = await self.server.step.remote()
        self.is_server_running = False
        # Notify the waiting coroutines that there are new outputs ready.
        for request_output in request_outputs:
            request_id = request_output.request_id
            self.request_outputs[request_id] = request_output
            self.request_events[request_id].set()

    async def generate(self, request_dict: Dict[str, Any]):
        # Preprocess the request.
        arrival_time = time.time()
        prompt = request_dict.pop("prompt")
        sampling_params = SamplingParams(**request_dict)

        # Create an event to notify us that there is new output from the
        # cacheflow server.
        request_id = str(uuid.uuid4().hex[:8])
        request_event = asyncio.Event()
        self.request_events[request_id] = request_event

        # Add the request into the cacheflow server's waiting queue.
        await self.server.add_request.remote(
            request_id, prompt, sampling_params, arrival_time=arrival_time)

        # The cacheflow server does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the server to process the requests.
        while True:
            # Kick the server if the server is not running.
            if not self.is_server_running:
                await self.server_step()

            # Wait for new output. The group_event will be set in server_step
            # when there is new output available for the sequence group.
            # Added a timeout to prevent deadlock.
            try:
                await asyncio.wait_for(request_event.wait(),
                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            except asyncio.TimeoutError:
                continue
            # Reset the event to wait for the next output.
            request_event.clear()

            # Decode and return new outputs.
            request_output = self.request_outputs[request_id]
            prompt = request_output.prompt
            text_outputs = [
                prompt + output.text
                for output in request_output.outputs
            ]
            ret = {
                "text": text_outputs,
                "error": 0,
            }
            yield (json.dumps(ret) + "\0").encode("utf-8")

            # Once finished, release the resources of the sequence group.
            if request_output.done:
                del self.request_outputs[request_id]
                del self.request_events[request_id]
                # Kick the server if the server is not running. This is to
                # prevent that there are still requests in server's waiting
                # queue to be executed.
                if not self.is_server_running:
                    await self.server_step()
                break


@app.post("/generate")
async def generate_stream(request: Request):
    request_dict = await request.json()
    return StreamingResponse(server.generate(request_dict))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    parser = add_server_arguments(parser)
    args = parser.parse_args()

    server_configs = create_server_configs_from_args(args)
    parallel_config = server_configs[2]
    distributed_init_method, stage_devices = initialize_cluster(parallel_config)

    server = FastAPIServer(
        args.use_ray, *server_configs, distributed_init_method, stage_devices)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
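The endpoint streams incremental results as NUL-terminated JSON objects (see the `yield` above). A minimal client sketch, assuming the default `localhost:10002` and the third-party `requests` library (not part of this commit); the extra fields are forwarded verbatim into `SamplingParams(**request_dict)`, so `n` and `max_tokens` are assumed to be valid keyword arguments there.

```python
# Hypothetical client for the /generate endpoint defined above.
import json

import requests

response = requests.post(
    "http://localhost:10002/generate",
    json={"prompt": "Hello, my name is", "n": 1, "max_tokens": 64},
    stream=True,
)
# Each streamed chunk is a complete JSON object terminated by a NUL byte.
for chunk in response.iter_lines(delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"])
```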
@@ -1,201 +0,0 @@
import argparse
import asyncio
import json
import time
from typing import List, Dict, Optional

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import ray
import uvicorn

from cacheflow.core.server import (Server, add_server_arguments,
                                   process_server_arguments,
                                   initialize_cluster)
from cacheflow.frontend.utils import get_tokenizer
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter
from cacheflow.worker.controller import DeviceID

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
app = FastAPI()


class FastAPIServer:
    def __init__(
        self,
        model: str,
        cache_dir: Optional[str],
        use_np_cache: bool,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        gpu_memory_utilization: float,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
        server_use_ray: bool,
        log_stats: bool,
    ):
        self.block_size = block_size

        self.tokenizer = get_tokenizer(model)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        if server_use_ray:
            remote_server_class = ray.remote(num_cpus=0)(Server)
        else:
            remote_server_class = ray.remote(num_gpus=1)(Server)
        self.server = remote_server_class.remote(
            model=model,
            cache_dir=cache_dir,
            use_dummy_weights=False,
            use_np_cache=use_np_cache,
            pipeline_parallel_size=pipeline_parallel_size,
            tensor_parallel_size=tensor_parallel_size,
            block_size=block_size,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            gpu_memory_utilization=gpu_memory_utilization,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_sequences=max_num_sequences,
            num_nodes=num_nodes,
            num_devices_per_node=num_devices_per_node,
            distributed_init_method=distributed_init_method,
            all_stage_devices=all_stage_devices,
            use_ray=server_use_ray,
            log_stats=log_stats,
        )

        self.running_seq_groups: Dict[int, SequenceGroup] = {}
        self.sequence_group_events: Dict[int, asyncio.Event] = {}
        self.is_server_running = False

    async def server_step(self):
        self.is_server_running = True
        updated_seq_groups = await self.server.step.remote()
        self.is_server_running = False
        # Notify the waiting coroutines that there are new outputs ready.
        for seq_group in updated_seq_groups:
            group_id = seq_group.group_id
            self.running_seq_groups[group_id] = seq_group
            self.sequence_group_events[group_id].set()

    async def generate(self, request_dict: Dict):
        # Preprocess the request.
        prompt = request_dict.pop("prompt")
        sampling_params = SamplingParams(**request_dict)
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
        token_ids = self.tokenizer.encode(prompt)
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
            seqs.append(seq)

        arrival_time = time.time()
        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs, arrival_time)
        # Create an event to notify us that there is new output from the
        # cacheflow server.
        group_event = asyncio.Event()
        self.running_seq_groups[group_id] = seq_group
        self.sequence_group_events[group_id] = group_event
        # Add the request into the cacheflow server's waiting queue.
        await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
        # The cacheflow server does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the server to process the requests.
        while True:
            # Kick the server if the server is not running.
            if not self.is_server_running:
                await self.server_step()
            # Wait for new output. The group_event will be set in server_step
            # when there is new output available for the sequence group.
            # Added a timeout to prevent deadlock.
            try:
                await asyncio.wait_for(group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            except asyncio.TimeoutError:
                continue
            # Reset the event to wait for the next output.
            group_event.clear()
            # Decode and return new outputs
            seq_group = self.running_seq_groups[group_id]
            all_outputs = []
            for seq in seq_group.seqs:
                token_ids = seq.get_token_ids()
                output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
                all_outputs.append(output)
            ret = {
                "text": all_outputs,
                "error": 0,
            }
            yield (json.dumps(ret) + "\0").encode("utf-8")

            # Once finished, release the resources of the sequence group.
            if seq_group.is_finished():
                del self.running_seq_groups[group_id]
                del self.sequence_group_events[group_id]
                # Kick the server if the server is not running. This is to
                # prevent that there are still requests in server's waiting
                # queue to be executed.
                if not self.is_server_running:
                    await self.server_step()
                break


@app.post("/generate")
async def generate_stream(request: Request):
    request_dict = await request.json()
    return StreamingResponse(server.generate(request_dict))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    parser = add_server_arguments(parser)
    args = parser.parse_args()
    args = process_server_arguments(args)

    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_cluster(
            use_ray=True,
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    server = FastAPIServer(
        model=args.model,
        cache_dir=args.cache_dir,
        use_np_cache=args.use_np_cache,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        gpu_memory_utilization=args.gpu_memory_utilization,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_sequences=args.max_num_sequences,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        server_use_ray=args.use_ray,
        log_stats=args.log_stats,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
@ -1,72 +0,0 @@
|
|||||||
import time
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from cacheflow.frontend.utils import get_tokenizer
|
|
||||||
from cacheflow.logger import init_logger
|
|
||||||
from cacheflow.sampling_params import SamplingParams
|
|
||||||
from cacheflow.sequence import Sequence, SequenceGroup
|
|
||||||
from cacheflow.utils import Counter
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleFrontend:
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
block_size: int,
|
|
||||||
) -> None:
|
|
||||||
self.block_size = block_size
|
|
||||||
|
|
||||||
self.tokenizer = get_tokenizer(model_name)
|
|
||||||
self.seq_group_counter = Counter()
|
|
||||||
self.seq_counter = Counter()
|
|
||||||
self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []
|
|
||||||
|
|
||||||
def add_eos_token(self, sampling_params: SamplingParams) -> SamplingParams:
|
|
||||||
# Stop generation when we see an EOS token.
|
|
||||||
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
|
|
||||||
return sampling_params
|
|
||||||
|
|
||||||
def query(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
sampling_params: SamplingParams,
|
|
||||||
) -> None:
|
|
||||||
token_ids = self.tokenizer.encode(prompt)
|
|
||||||
self._add_query(prompt, token_ids, sampling_params)
|
|
||||||
|
|
||||||
def _add_query(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
token_ids: List[int],
|
|
||||||
sampling_params: SamplingParams,
|
|
||||||
arrival_time: Optional[float] = None,
|
|
||||||
) -> None:
|
|
||||||
if arrival_time is None:
|
|
||||||
arrival_time = time.time()
|
|
||||||
seqs: List[Sequence] = []
|
|
||||||
for _ in range(sampling_params.n):
|
|
||||||
seq_id = next(self.seq_counter)
|
|
||||||
seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
|
|
||||||
seqs.append(seq)
|
|
||||||
|
|
||||||
group_id = next(self.seq_group_counter)
|
|
||||||
seq_group = SequenceGroup(group_id, seqs, arrival_time)
|
|
||||||
self.inputs.append((seq_group, sampling_params))
|
|
||||||
|
|
||||||
def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
|
|
||||||
inputs = self.inputs
|
|
||||||
self.inputs = []
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
def print_response(
|
|
||||||
self,
|
|
||||||
seq_group: SequenceGroup,
|
|
||||||
) -> None:
|
|
||||||
for seq in seq_group.seqs:
|
|
||||||
token_ids = seq.get_token_ids()
|
|
||||||
output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
|
|
||||||
output = output.strip()
|
|
||||||
logger.info(f"Seq {seq.seq_id}: {output!r}")
|
|
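The deleted SimpleFrontend handled tokenization and EOS termination on the client side; after this refactor the same EOS trick lives in `LLMServer.add_request`. A minimal sketch of that pattern, assuming a Hugging Face tokenizer for an OPT-style model (the model name is only an example):

```python
# Sketch only: the EOS stop-token pattern from SimpleFrontend.add_eos_token,
# which LLMServer.add_request now applies internally.
from transformers import AutoTokenizer

from cacheflow.sampling_params import SamplingParams

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # model name is an assumption
sampling_params = SamplingParams()

# Stop generation when the model emits its end-of-sequence token.
sampling_params.stop_token_ids.add(tokenizer.eos_token_id)
print(sampling_params)
```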
@@ -1,12 +1,10 @@
 from cacheflow.model_executor.input_metadata import InputMetadata
 from cacheflow.model_executor.model_loader import get_model
-from cacheflow.model_executor.utils import (set_random_seed,
-                                            get_cache_block_size)
+from cacheflow.model_executor.utils import set_random_seed


 __all__ = [
     "InputMetadata",
-    "get_cache_block_size",
     "get_model",
     "set_random_seed",
 ]
@ -10,9 +10,9 @@ from cacheflow import cache_ops
|
|||||||
from cacheflow import pos_encoding_ops
|
from cacheflow import pos_encoding_ops
|
||||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
|
_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
|
||||||
|
|
||||||
|
|
||||||
class GPTCacheFlowAttention(nn.Module):
|
class GPTCacheFlowAttention(nn.Module):
|
||||||
"""GPT-style multi-head attention.
|
"""GPT-style multi-head attention.
|
||||||
|
|
||||||
|
@ -1,16 +1,13 @@
|
|||||||
"""Utilities for selecting and loading models."""
|
"""Utilities for selecting and loading models."""
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import AutoConfig, PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from cacheflow.config import ModelConfig
|
||||||
from cacheflow.model_executor.models import (
|
from cacheflow.model_executor.models import (
|
||||||
GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
|
GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
|
||||||
from cacheflow.model_executor.utils import get_torch_dtype
|
|
||||||
from cacheflow.model_executor.weight_utils import initialize_dummy_weights
|
from cacheflow.model_executor.weight_utils import initialize_dummy_weights
|
||||||
|
|
||||||
|
|
||||||
# TODO(woosuk): Lazy-load the model classes.
|
# TODO(woosuk): Lazy-load the model classes.
|
||||||
_MODEL_REGISTRY = {
|
_MODEL_REGISTRY = {
|
||||||
"GPT2LMHeadModel": GPT2LMHeadModel,
|
"GPT2LMHeadModel": GPT2LMHeadModel,
|
||||||
@ -19,6 +16,7 @@ _MODEL_REGISTRY = {
|
|||||||
"OPTForCausalLM": OPTForCausalLM,
|
"OPTForCausalLM": OPTForCausalLM,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
|
def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
|
||||||
architectures = getattr(config, "architectures", [])
|
architectures = getattr(config, "architectures", [])
|
||||||
for arch in architectures:
|
for arch in architectures:
|
||||||
@ -30,51 +28,22 @@ def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
|
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
model_class = _get_model_architecture(model_config.hf_config)
|
||||||
# because config.torch_dtype can be None.
|
torch.set_default_dtype(model_config.dtype)
|
||||||
config_dtype = getattr(config, "torch_dtype", None)
|
|
||||||
if config_dtype is None:
|
|
||||||
config_dtype = torch.float32
|
|
||||||
if dtype == "default":
|
|
||||||
if config_dtype == torch.float32:
|
|
||||||
# Following the common practice, we use float16 for float32 models.
|
|
||||||
torch_dtype = torch.float16
|
|
||||||
else:
|
|
||||||
torch_dtype = config_dtype
|
|
||||||
else:
|
|
||||||
torch_dtype = get_torch_dtype(dtype)
|
|
||||||
if torch_dtype != config_dtype and config_dtype != torch.float32:
|
|
||||||
# TODO(woosuk): Allow using float16 for bfloat16 models and
|
|
||||||
# vice versa. Print a warning message and continue.
|
|
||||||
raise ValueError(
|
|
||||||
f"Cannot use {torch_dtype} for {config_dtype} model.")
|
|
||||||
return torch_dtype
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(
|
|
||||||
model_name: str,
|
|
||||||
dtype: str,
|
|
||||||
cache_dir: Optional[str],
|
|
||||||
use_dummy_weights: bool,
|
|
||||||
use_np_cache: bool,
|
|
||||||
) -> nn.Module:
|
|
||||||
config = AutoConfig.from_pretrained(model_name)
|
|
||||||
torch_dtype = _get_dtype(config, dtype)
|
|
||||||
torch.set_default_dtype(torch_dtype)
|
|
||||||
model_class = _get_model_architecture(config)
|
|
||||||
|
|
||||||
# Create a model instance.
|
# Create a model instance.
|
||||||
# The weights will be initialized as empty tensors.
|
# The weights will be initialized as empty tensors.
|
||||||
model = model_class(config)
|
model = model_class(model_config.hf_config)
|
||||||
if use_dummy_weights:
|
if model_config.use_dummy_weights:
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||||
# random values to the weights.
|
# random values to the weights.
|
||||||
initialize_dummy_weights(model)
|
initialize_dummy_weights(model)
|
||||||
else:
|
else:
|
||||||
# Load the weights from the cached or downloaded files.
|
# Load the weights from the cached or downloaded files.
|
||||||
model.load_weights(model_name, cache_dir, use_np_cache)
|
model.load_weights(
|
||||||
|
model_config.model, model_config.download_dir,
|
||||||
|
model_config.use_np_weights)
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
return model.eval(), torch_dtype
|
return model.eval()
|
||||||
|
|
||||||
|
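The loader now takes a single `ModelConfig` instead of a handful of loose keyword arguments and no longer returns the dtype separately. A hedged sketch of the new call, building the config with the same positional arguments that `create_server_configs_from_args` uses below (the concrete values are placeholders, and loading requires a GPU):

```python
# Hedged sketch of the new loader entry point; the exact ModelConfig
# signature lives in cacheflow/config.py.
from cacheflow.config import ModelConfig
from cacheflow.model_executor import get_model

model_config = ModelConfig(
    "facebook/opt-125m",  # model
    None,                 # download_dir
    False,                # use_np_weights
    False,                # use_dummy_weights
    "default",            # dtype
    0,                    # seed
)

# Returns the model in eval mode, already moved to the GPU.
model = get_model(model_config)
```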
@ -1,6 +1,5 @@
|
|||||||
"""Utils for model executor."""
|
"""Utils for model executor."""
|
||||||
import random
|
import random
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -9,28 +8,6 @@ from cacheflow.model_executor.parallel_utils.parallel_state import model_paralle
|
|||||||
from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed
|
from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed
|
||||||
|
|
||||||
|
|
||||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
|
||||||
"half": torch.half,
|
|
||||||
"float": torch.float,
|
|
||||||
"float16": torch.float16,
|
|
||||||
"float32": torch.float32,
|
|
||||||
"bfloat16": torch.bfloat16,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
|
|
||||||
if isinstance(dtype, str):
|
|
||||||
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
|
|
||||||
else:
|
|
||||||
torch_dtype = dtype
|
|
||||||
return torch_dtype
|
|
||||||
|
|
||||||
|
|
||||||
def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
|
|
||||||
torch_dtype = get_torch_dtype(dtype)
|
|
||||||
return torch.tensor([], dtype=torch_dtype).element_size()
|
|
||||||
|
|
||||||
|
|
||||||
def set_random_seed(seed: int) -> None:
|
def set_random_seed(seed: int) -> None:
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
@ -40,15 +17,3 @@ def set_random_seed(seed: int) -> None:
|
|||||||
|
|
||||||
if model_parallel_is_initialized():
|
if model_parallel_is_initialized():
|
||||||
model_parallel_cuda_manual_seed(seed)
|
model_parallel_cuda_manual_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
def get_cache_block_size(block_size: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
num_layers: int,
|
|
||||||
dtype: str) -> int:
|
|
||||||
key_cache_block = block_size * num_heads * head_size
|
|
||||||
value_cache_block = key_cache_block
|
|
||||||
total = num_layers * (key_cache_block + value_cache_block)
|
|
||||||
dtype_size = get_dtype_size(dtype)
|
|
||||||
return dtype_size * total
|
|
||||||
|
79 cacheflow/outputs.py Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
|
|
||||||
|
from cacheflow.sequence import SequenceGroup
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionOutput:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
token_ids: List[int],
|
||||||
|
cumulative_logprobs: float,
|
||||||
|
logprobs: List[Dict[int, float]],
|
||||||
|
) -> None:
|
||||||
|
self.text = text
|
||||||
|
self.token_ids = token_ids
|
||||||
|
self.cumulative_logprobs = cumulative_logprobs
|
||||||
|
self.logprobs = logprobs
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (f"CompletionOutput(output={self.text!r}, "
|
||||||
|
f"token_ids={self.token_ids}, "
|
||||||
|
f"cumulative_logprobs={self.cumulative_logprobs}, "
|
||||||
|
f"logprobs={self.logprobs})")
|
||||||
|
|
||||||
|
|
||||||
|
class RequestOutput:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
request_id: int,
|
||||||
|
prompt: str,
|
||||||
|
prompt_token_ids: List[int],
|
||||||
|
outputs: List[CompletionOutput],
|
||||||
|
done: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self.prompt = prompt
|
||||||
|
self.prompt_token_ids = prompt_token_ids
|
||||||
|
self.outputs = outputs
|
||||||
|
self.done = done
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_seq_group(
|
||||||
|
seq_group: SequenceGroup,
|
||||||
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||||
|
) -> "RequestOutput":
|
||||||
|
outputs: List[CompletionOutput] = []
|
||||||
|
seqs = seq_group.get_seqs()
|
||||||
|
for seq in seqs:
|
||||||
|
output_token_ids = seq.data.output_token_ids
|
||||||
|
output_str = tokenizer.decode(output_token_ids,
|
||||||
|
skip_special_tokens=True)
|
||||||
|
seq_logprobs = seq.data.cumulative_logprobs
|
||||||
|
|
||||||
|
logprobs = seq.output_logprobs
|
||||||
|
if seq_group.sampling_params.logprobs == 0:
|
||||||
|
# NOTE: We need to take care of this case because the sequence
|
||||||
|
# always has the logprobs of the sampled tokens even if the
|
||||||
|
# logprobs are not requested.
|
||||||
|
logprobs = {}
|
||||||
|
output = CompletionOutput(output_str, output_token_ids,
|
||||||
|
seq_logprobs, logprobs)
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
|
# Every sequence in the sequence group should have the same prompt.
|
||||||
|
prompt = seqs[0].prompt
|
||||||
|
prompt_token_ids = seqs[0].data.prompt_token_ids
|
||||||
|
return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
|
||||||
|
outputs, seq_group.is_finished())
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (f"RequestOutput(request_id={self.request_id}, "
|
||||||
|
f"prompt={self.prompt!r}, "
|
||||||
|
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||||
|
f"outputs={self.outputs}, "
|
||||||
|
f"done={self.done})")
|
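`RequestOutput` is the new user-facing result object: one entry per request, with one `CompletionOutput` per generated sequence. An illustrative helper for walking its fields, assuming the object comes from `LLMServer.step()`:

```python
# Illustrative only: inspecting a RequestOutput produced by LLMServer.step().
def show(request_output) -> None:
    print(f"request {request_output.request_id}: prompt={request_output.prompt!r}")
    for completion in request_output.outputs:  # List[CompletionOutput]
        print(f"  text={completion.text!r}")
        print(f"  token_ids={completion.token_ids}")
        print(f"  cumulative_logprobs={completion.cumulative_logprobs}")
    if request_output.done:
        print("  (finished)")
```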
@@ -116,4 +116,4 @@ class SamplingParams:
                 f"use_beam_search={self.use_beam_search}, "
                 f"stop_token_ids={self.stop_token_ids}, "
                 f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs}")
+                f"logprobs={self.logprobs})")
@@ -115,12 +115,14 @@ class SequenceGroup:

     def __init__(
         self,
-        group_id: int,
+        request_id: str,
         seqs: List[Sequence],
+        sampling_params: SamplingParams,
         arrival_time: float,
     ) -> None:
-        self.group_id = group_id
+        self.request_id = request_id
         self.seqs = seqs
+        self.sampling_params = sampling_params
         self.arrival_time = arrival_time

     def get_seqs(
@@ -145,21 +147,22 @@ class SequenceGroup:
         return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)

     def __repr__(self) -> str:
-        return (f'SequenceGroup(group_id={self.group_id}, '
-                f'num_seqs={len(self.seqs)})')
+        return (f"SequenceGroup(request_id={self.request_id}, "
+                f"sampling_params={self.sampling_params}, "
+                f"num_seqs={len(self.seqs)})")


 class SequenceGroupMetadata:

     def __init__(
         self,
-        group_id: int,
+        request_id: str,
         is_prompt: bool,
         seq_data: Dict[int, SequenceData],  # Seq id -> sequence data.
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],  # Seq id -> list of physical block numbers.
     ) -> None:
-        self.group_id = group_id
+        self.request_id = request_id
         self.is_prompt = is_prompt
         self.seq_data = seq_data
         self.sampling_params = sampling_params
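`SequenceGroup` now carries a string `request_id` and its own `SamplingParams`. A sketch of how one is assembled, mirroring `LLMServer.add_request`; the tokenizer, model name, and block size are assumptions for illustration:

```python
# Sketch of assembling a SequenceGroup under the new interface.
import time

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.server.tokenizer_utils import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")  # model name is an assumption
prompt = "A robot may not injure a human being"
prompt_token_ids = tokenizer.encode(prompt)

block_size = 16  # assumed KV-cache block size
sampling_params = SamplingParams(n=2)
seqs = [
    Sequence(seq_id, prompt, prompt_token_ids, block_size)
    for seq_id in range(sampling_params.n)
]
seq_group = SequenceGroup("request-0", seqs, sampling_params, time.time())
print(seq_group)  # SequenceGroup(request_id=..., sampling_params=..., num_seqs=2)
```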
74 cacheflow/server/arg_utils.py Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import argparse
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||||
|
SchedulerConfig)
|
||||||
|
from cacheflow.server.llm_server import LLMServer
|
||||||
|
from cacheflow.server.ray_utils import initialize_cluster
|
||||||
|
|
||||||
|
_GiB = 1 << 30
|
||||||
|
|
||||||
|
|
||||||
|
def add_server_arguments(parser: argparse.ArgumentParser):
|
||||||
|
"""Shared arguments for CacheFlow servers."""
|
||||||
|
# Model arguments
|
||||||
|
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
|
||||||
|
parser.add_argument('--download-dir', type=str, default=None,
|
||||||
|
help='directory to download and load the weights, '
|
||||||
|
'default to the default cache dir of huggingface')
|
||||||
|
parser.add_argument('--use-np-weights', action='store_true',
|
||||||
|
help='save a numpy copy of model weights for faster loading')
|
||||||
|
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
|
||||||
|
# TODO(woosuk): Support FP32.
|
||||||
|
parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
|
||||||
|
help=('data type for model weights and activations. '
|
||||||
|
'The "default" option will use FP16 precision '
|
||||||
|
'for FP32 and FP16 models, and BF16 precision '
|
||||||
|
'for BF16 models.'))
|
||||||
|
# Parallel arguments
|
||||||
|
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
|
||||||
|
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
|
||||||
|
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
|
||||||
|
# KV cache arguments
|
||||||
|
parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
|
||||||
|
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
|
||||||
|
parser.add_argument('--seed', type=int, default=0, help='random seed')
|
||||||
|
parser.add_argument('--swap-space', type=int, default=4, help='CPU swap space size (GiB) per GPU')
|
||||||
|
parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
|
||||||
|
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
|
||||||
|
parser.add_argument('--max-num-seqs', type=int, default=256, help='maximum number of sequences per iteration')
|
||||||
|
parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def create_server_configs_from_args(
|
||||||
|
args: argparse.Namespace,
|
||||||
|
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
||||||
|
# Post-process the parsed arguments.
|
||||||
|
args.swap_space = args.swap_space * _GiB
|
||||||
|
args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)
|
||||||
|
|
||||||
|
# Initialize the configs.
|
||||||
|
model_config = ModelConfig(
|
||||||
|
args.model, args.download_dir, args.use_np_weights,
|
||||||
|
args.use_dummy_weights, args.dtype, args.seed)
|
||||||
|
cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
|
||||||
|
args.swap_space)
|
||||||
|
parallel_config = ParallelConfig(args.pipeline_parallel_size,
|
||||||
|
args.tensor_parallel_size, args.use_ray)
|
||||||
|
scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
|
||||||
|
args.max_num_seqs)
|
||||||
|
return model_config, cache_config, parallel_config, scheduler_config
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
|
||||||
|
server_configs = create_server_configs_from_args(args)
|
||||||
|
parallel_config = server_configs[2]
|
||||||
|
|
||||||
|
# Initialize the cluster.
|
||||||
|
distributed_init_method, devices = initialize_cluster(parallel_config)
|
||||||
|
|
||||||
|
# Create the LLM server.
|
||||||
|
server = LLMServer(*server_configs, distributed_init_method, devices,
|
||||||
|
log_stats=not args.disable_log_stats)
|
||||||
|
return server
|
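`arg_utils` turns one flat argparse namespace into the four config objects that the rest of the system consumes. A minimal sketch of driving these helpers directly; the flag values are examples only:

```python
# Minimal sketch of the new argument helpers; flag values are placeholders.
import argparse

from cacheflow.server.arg_utils import (add_server_arguments,
                                        create_server_configs_from_args)

parser = argparse.ArgumentParser(description="CacheFlow config demo")
parser = add_server_arguments(parser)
args = parser.parse_args(["--model", "facebook/opt-125m", "--block-size", "16"])

# Splits the namespace into (ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig).
model_config, cache_config, parallel_config, scheduler_config = (
    create_server_configs_from_args(args))
print(model_config.model, cache_config.block_size)
```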
198 cacheflow/server/llm_server.py Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
import time
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ray
|
||||||
|
except ImportError:
|
||||||
|
ray = None
|
||||||
|
|
||||||
|
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||||
|
SchedulerConfig)
|
||||||
|
from cacheflow.core.scheduler import Scheduler
|
||||||
|
from cacheflow.logger import init_logger
|
||||||
|
from cacheflow.outputs import RequestOutput
|
||||||
|
from cacheflow.sampling_params import SamplingParams
|
||||||
|
from cacheflow.server.tokenizer_utils import get_tokenizer
|
||||||
|
from cacheflow.sequence import Sequence, SequenceGroup
|
||||||
|
from cacheflow.utils import Counter
|
||||||
|
from cacheflow.worker.worker import Worker
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMServer:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
cache_config: CacheConfig,
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
scheduler_config: SchedulerConfig,
|
||||||
|
distributed_init_method: str,
|
||||||
|
stage_devices: List[List[Any]],
|
||||||
|
log_stats: bool = True,
|
||||||
|
) -> None:
|
||||||
|
logger.info(
|
||||||
|
"Initializing an LLM server with config: "
|
||||||
|
f"model={model_config.model!r}, "
|
||||||
|
f"dtype={model_config.dtype}, "
|
||||||
|
f"use_dummy_weights={model_config.use_dummy_weights}, "
|
||||||
|
f"download_dir={model_config.download_dir!r}, "
|
||||||
|
f"use_np_weights={model_config.use_np_weights}, "
|
||||||
|
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||||
|
f"seed={model_config.seed})"
|
||||||
|
)
|
||||||
|
# TODO(woosuk): Print more configs in debug mode.
|
||||||
|
|
||||||
|
self.model_config = model_config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.parallel_config = parallel_config
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.log_stats = log_stats
|
||||||
|
|
||||||
|
self._verify_args()
|
||||||
|
|
||||||
|
self.tokenizer = get_tokenizer(model_config.model)
|
||||||
|
self.seq_counter = Counter()
|
||||||
|
|
||||||
|
# Create the parallel GPU workers.
|
||||||
|
self.workers: List[Worker] = []
|
||||||
|
assert len(stage_devices) == 1, "Only support one stage for now."
|
||||||
|
for rank, node_resource, _ in stage_devices[0]:
|
||||||
|
worker_cls = Worker
|
||||||
|
if self.parallel_config.use_ray:
|
||||||
|
worker_cls = ray.remote(
|
||||||
|
num_cpus=0,
|
||||||
|
num_gpus=1,
|
||||||
|
resources={node_resource: 1e-5},
|
||||||
|
)(worker_cls).remote
|
||||||
|
|
||||||
|
worker = worker_cls(
|
||||||
|
model_config,
|
||||||
|
parallel_config,
|
||||||
|
scheduler_config,
|
||||||
|
rank,
|
||||||
|
distributed_init_method,
|
||||||
|
)
|
||||||
|
self.workers.append(worker)
|
||||||
|
# Profile the memory usage and initialize the cache.
|
||||||
|
self._init_cache()
|
||||||
|
|
||||||
|
# Create the scheduler.
|
||||||
|
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
|
||||||
|
|
||||||
|
def _verify_args(self) -> None:
|
||||||
|
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||||
|
|
||||||
|
def _init_cache(self) -> None:
|
||||||
|
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
||||||
|
num_blocks = self._run_workers(
|
||||||
|
"profile_num_available_blocks",
|
||||||
|
get_all_outputs=True,
|
||||||
|
block_size=self.cache_config.block_size,
|
||||||
|
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
|
||||||
|
cpu_swap_space=self.cache_config.swap_space,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Since we use a shared centralized controller, we take the minimum
|
||||||
|
# number of blocks across all workers to make sure all the memory
|
||||||
|
# operators can be applied to all workers.
|
||||||
|
num_gpu_blocks = min(b[0] for b in num_blocks)
|
||||||
|
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||||
|
# FIXME(woosuk): Change to debug log.
|
||||||
|
logger.info(f'# GPU blocks: {num_gpu_blocks}, '
|
||||||
|
f'# CPU blocks: {num_cpu_blocks}')
|
||||||
|
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
|
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
|
# Initialize the cache.
|
||||||
|
self._run_workers("init_cache_engine", cache_config=self.cache_config)
|
||||||
|
|
||||||
|
def add_request(
|
||||||
|
self,
|
||||||
|
request_id: str,
|
||||||
|
prompt: str,
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
prompt_token_ids: Optional[List[int]] = None,
|
||||||
|
arrival_time: Optional[float] = None,
|
||||||
|
) -> None:
|
||||||
|
if arrival_time is None:
|
||||||
|
arrival_time = time.time()
|
||||||
|
if prompt_token_ids is None:
|
||||||
|
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||||
|
|
||||||
|
# Create the sequences.
|
||||||
|
block_size = self.cache_config.block_size
|
||||||
|
seqs: List[Sequence] = []
|
||||||
|
for _ in range(sampling_params.n):
|
||||||
|
seq_id = next(self.seq_counter)
|
||||||
|
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
||||||
|
seqs.append(seq)
|
||||||
|
|
||||||
|
# FIXME(woosuk)
|
||||||
|
# Add the EOS token to the stop token list.
|
||||||
|
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
|
||||||
|
|
||||||
|
# Create the sequence group.
|
||||||
|
seq_group = SequenceGroup(request_id, seqs, sampling_params,
|
||||||
|
arrival_time)
|
||||||
|
|
||||||
|
# Add the sequence group to the scheduler.
|
||||||
|
self.scheduler.add_seq_group(seq_group)
|
||||||
|
|
||||||
|
def has_unfinished_requests(self) -> bool:
|
||||||
|
return self.scheduler.has_unfinished_seqs()
|
||||||
|
|
||||||
|
def step(self) -> List[RequestOutput]:
|
||||||
|
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||||
|
if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
|
||||||
|
# Nothing to do.
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Execute the model.
|
||||||
|
output = self._run_workers(
|
||||||
|
"execute_model",
|
||||||
|
seq_group_metadata_list=seq_group_metadata_list,
|
||||||
|
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||||
|
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||||
|
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||||
|
)
|
||||||
|
# Update the scheduler.
|
||||||
|
updated_seq_groups = self.scheduler.update(output)
|
||||||
|
|
||||||
|
# Create the outputs.
|
||||||
|
request_outputs: List[RequestOutput] = []
|
||||||
|
for seq_group in updated_seq_groups:
|
||||||
|
# TODO(woosuk): Batch-decode the outputs for speedup.
|
||||||
|
request_output = RequestOutput.from_seq_group(seq_group,
|
||||||
|
self.tokenizer)
|
||||||
|
request_outputs.append(request_output)
|
||||||
|
return request_outputs
|
||||||
|
|
||||||
|
def _run_workers(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
get_all_outputs: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> Any:
|
||||||
|
all_outputs = []
|
||||||
|
for worker in self.workers:
|
||||||
|
executor = getattr(worker, method)
|
||||||
|
if self.parallel_config.use_ray:
|
||||||
|
executor = executor.remote
|
||||||
|
|
||||||
|
output = executor(*args, **kwargs)
|
||||||
|
all_outputs.append(output)
|
||||||
|
|
||||||
|
if self.parallel_config.use_ray:
|
||||||
|
all_outputs = ray.get(all_outputs)
|
||||||
|
|
||||||
|
if get_all_outputs:
|
||||||
|
return all_outputs
|
||||||
|
|
||||||
|
# Make sure all workers have the same results.
|
||||||
|
output = all_outputs[0]
|
||||||
|
for other_output in all_outputs[1:]:
|
||||||
|
assert output == other_output
|
||||||
|
return output
|
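`LLMServer` is driven synchronously: callers enqueue requests with `add_request` and poll `step()`, which runs one scheduling iteration, executes the model on every worker, and returns the updated `RequestOutput`s. A hedged sketch of that loop (the model name and prompt are placeholders; see `examples/simple_server.py` below for the full script):

```python
# Sketch of the synchronous LLMServer driver loop; requires a GPU to run.
import argparse
import uuid

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import (add_server_arguments,
                                        initialize_server_from_args)

parser = add_server_arguments(argparse.ArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m"])  # placeholder model
server = initialize_server_from_args(args)

server.add_request(str(uuid.uuid4().hex[:8]),
                   "What is the meaning of life?",
                   SamplingParams(n=2, temperature=0.8))

while server.has_unfinished_requests():
    # One scheduler + model iteration; finished requests come back with done=True.
    for request_output in server.step():
        if request_output.done:
            print(request_output)
```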
90 cacheflow/server/ray_utils.py Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
import random
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ray
|
||||||
|
except ImportError:
|
||||||
|
ray = None
|
||||||
|
|
||||||
|
from cacheflow.config import ParallelConfig
|
||||||
|
|
||||||
|
DeviceID = Tuple[int, str, int] # rank, node resource (node IP), device id
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_cluster(
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
address: Optional[str] = None,
|
||||||
|
) -> Tuple[str, List[List[DeviceID]]]:
|
||||||
|
if not parallel_config.use_ray:
|
||||||
|
# Initialize cluster locally.
|
||||||
|
port = random.randint(10000, 20000)
|
||||||
|
# We need to setup the distributed init method to make sure
|
||||||
|
# the distributed megatron code (e.g., get world size) works correctly.
|
||||||
|
distributed_init_method = f"tcp://localhost:{port}"
|
||||||
|
all_stage_devices = [[(0, None, 0)]]
|
||||||
|
return distributed_init_method, all_stage_devices
|
||||||
|
|
||||||
|
if ray is None:
|
||||||
|
raise ImportError(
|
||||||
|
"Ray is not installed. Please install Ray to use distributed "
|
||||||
|
"serving.")
|
||||||
|
# Connect to a ray cluster.
|
||||||
|
ray.init(address=address)
|
||||||
|
|
||||||
|
# Assume we have a uniform cluster that each node has the same number of
|
||||||
|
# GPUs for now.
|
||||||
|
valid_node_resources = []
|
||||||
|
num_devices_per_node = None
|
||||||
|
for node in ray.nodes():
|
||||||
|
if (not node['Alive']) or node['Resources']['GPU'] <= 0:
|
||||||
|
continue
|
||||||
|
if num_devices_per_node is None:
|
||||||
|
num_devices_per_node = node['Resources']['GPU']
|
||||||
|
else:
|
||||||
|
assert num_devices_per_node == node['Resources']['GPU'], (
|
||||||
|
"The number of GPUs per node is not uniform.")
|
||||||
|
for key in node['Resources']:
|
||||||
|
if key.startswith('node:'):
|
||||||
|
valid_node_resources.append(key)
|
||||||
|
|
||||||
|
# Verify the parallel config.
|
||||||
|
num_nodes = len(valid_node_resources)
|
||||||
|
if parallel_config.world_size > num_nodes * num_devices_per_node:
|
||||||
|
raise ValueError(
|
||||||
|
"The number of required GPUs exceeds the total number of "
|
||||||
|
"available GPUs.")
|
||||||
|
if parallel_config.tensor_parallel_size >= num_devices_per_node:
|
||||||
|
if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"The number of tensor parallelism is not divisible by the "
|
||||||
|
"number of GPUs per node.")
|
||||||
|
else:
|
||||||
|
if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"The number of GPUs per node is not divisible by the number "
|
||||||
|
"of tensor parallelism.")
|
||||||
|
|
||||||
|
# Assign GPUs to pipeline stages.
|
||||||
|
rank = 0
|
||||||
|
current_node_id = 0
|
||||||
|
current_device_id = 0
|
||||||
|
distributed_init_method = None
|
||||||
|
all_stage_devices = []
|
||||||
|
|
||||||
|
for _ in range(parallel_config.pipeline_parallel_size):
|
||||||
|
stage_devices = []
|
||||||
|
for _ in range(parallel_config.tensor_parallel_size):
|
||||||
|
node_resource = valid_node_resources[current_node_id]
|
||||||
|
stage_devices.append((rank, node_resource, current_device_id))
|
||||||
|
if distributed_init_method is None:
|
||||||
|
ip = node_resource.split("node:")[-1]
|
||||||
|
port = random.randint(10000, 20000)
|
||||||
|
distributed_init_method = f"tcp://{ip}:{port}"
|
||||||
|
rank += 1
|
||||||
|
current_device_id += 1
|
||||||
|
if current_device_id >= num_devices_per_node:
|
||||||
|
current_node_id += 1
|
||||||
|
current_device_id = 0
|
||||||
|
all_stage_devices.append(stage_devices)
|
||||||
|
|
||||||
|
return distributed_init_method, all_stage_devices
|
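`initialize_cluster` returns the `torch.distributed` init URL plus the per-stage device assignments; without Ray it short-circuits to a single local device. A small sketch of that single-GPU path, building `ParallelConfig` with the same positional arguments used in `create_server_configs_from_args`:

```python
# Sketch: single-node path of initialize_cluster (no Ray required).
from cacheflow.config import ParallelConfig
from cacheflow.server.ray_utils import initialize_cluster

parallel_config = ParallelConfig(1, 1, False)  # pipeline=1, tensor=1, use_ray=False
distributed_init_method, all_stage_devices = initialize_cluster(parallel_config)

print(distributed_init_method)  # e.g. tcp://localhost:<random port>
print(all_stage_devices)        # [[(0, None, 0)]] -> rank 0, local node, device 0
```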
@ -3,7 +3,6 @@ from typing import Union
|
|||||||
from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
|
from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
|
||||||
PreTrainedTokenizerFast)
|
PreTrainedTokenizerFast)
|
||||||
|
|
||||||
|
|
||||||
_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
|
_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
|
||||||
# LLaMA fast tokenizer has a bug related to protobuf.
|
# LLaMA fast tokenizer has a bug related to protobuf.
|
||||||
# See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
|
# See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
|
@ -4,6 +4,7 @@ from typing import Dict, List, Tuple
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from cacheflow import cache_ops
|
from cacheflow import cache_ops
|
||||||
|
from cacheflow.config import CacheConfig, ModelConfig, ParallelConfig
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@ -18,27 +19,22 @@ class CacheEngine:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
worker_id: int,
|
cache_config: CacheConfig,
|
||||||
num_layers: int,
|
model_config: ModelConfig,
|
||||||
num_heads: int,
|
parallel_config: ParallelConfig,
|
||||||
head_size: int,
|
|
||||||
block_size: int,
|
|
||||||
num_gpu_blocks: int,
|
|
||||||
num_cpu_blocks: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if head_size % 16 != 0:
|
self.cache_config = cache_config
|
||||||
raise ValueError(
|
self.model_config = model_config
|
||||||
f'head_size ({head_size}) must be a multiple of 16.')
|
self.parallel_config = parallel_config
|
||||||
|
|
||||||
self.worker_id = worker_id
|
self.head_size = model_config.get_head_size()
|
||||||
self.num_layers = num_layers
|
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||||
self.num_heads = num_heads
|
self.num_heads = model_config.get_num_heads(parallel_config)
|
||||||
self.head_size = head_size
|
self.dtype = model_config.dtype
|
||||||
self.block_size = block_size
|
|
||||||
self.num_gpu_blocks = num_gpu_blocks
|
self.block_size = cache_config.block_size
|
||||||
self.num_cpu_blocks = num_cpu_blocks
|
self.num_gpu_blocks = cache_config.num_gpu_blocks
|
||||||
self.dtype = dtype
|
self.num_cpu_blocks = cache_config.num_cpu_blocks
|
||||||
|
|
||||||
# Initialize the cache.
|
# Initialize the cache.
|
||||||
self.gpu_cache = self.allocate_gpu_cache()
|
self.gpu_cache = self.allocate_gpu_cache()
|
||||||
@ -48,7 +44,7 @@ class CacheEngine:
|
|||||||
self.cache_stream = torch.cuda.Stream()
|
self.cache_stream = torch.cuda.Stream()
|
||||||
assert self.cache_stream != torch.cuda.current_stream()
|
assert self.cache_stream != torch.cuda.current_stream()
|
||||||
# Initialize the events for stream synchronization.
|
# Initialize the events for stream synchronization.
|
||||||
self.events = [torch.cuda.Event() for _ in range(num_layers)]
|
self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
|
||||||
|
|
||||||
def get_key_block_shape(self) -> Tuple[int, int, int, int]:
|
def get_key_block_shape(self) -> Tuple[int, int, int, int]:
|
||||||
element_size = torch.tensor([], dtype=self.dtype).element_size()
|
element_size = torch.tensor([], dtype=self.dtype).element_size()
|
||||||
@ -133,3 +129,23 @@ class CacheEngine:
|
|||||||
value_caches = [value_cache for _, value_cache in self.gpu_cache]
|
value_caches = [value_cache for _, value_cache in self.gpu_cache]
|
||||||
# NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
|
# NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
|
||||||
cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
|
cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_cache_block_size(
|
||||||
|
block_size: int,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
) -> int:
|
||||||
|
head_size = model_config.get_head_size()
|
||||||
|
num_heads = model_config.get_num_heads(parallel_config)
|
||||||
|
num_layers = model_config.get_num_layers(parallel_config)
|
||||||
|
|
||||||
|
key_cache_block = block_size * num_heads * head_size
|
||||||
|
value_cache_block = key_cache_block
|
||||||
|
total = num_layers * (key_cache_block + value_cache_block)
|
||||||
|
dtype_size = _get_dtype_size(model_config.dtype)
|
||||||
|
return dtype_size * total
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dtype_size(dtype: torch.dtype) -> int:
|
||||||
|
return torch.tensor([], dtype=dtype).element_size()
|
||||||
|
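The new `CacheEngine.get_cache_block_size` computes the byte size of one KV-cache block. A worked example of that arithmetic; the head and layer counts below are illustrative assumptions, not values read from a real config:

```python
# Worked example of the block-size arithmetic in CacheEngine.get_cache_block_size.
# The model numbers below are illustrative assumptions.
import torch

block_size = 16   # tokens per KV-cache block
num_heads = 12    # attention heads per worker (after tensor parallelism)
head_size = 64
num_layers = 12
dtype = torch.float16

key_cache_block = block_size * num_heads * head_size        # elements per key block
value_cache_block = key_cache_block                         # values mirror keys
total_elements = num_layers * (key_cache_block + value_cache_block)
dtype_size = torch.tensor([], dtype=dtype).element_size()   # 2 bytes for fp16

print(total_elements * dtype_size)  # 589824 bytes (~576 KiB) per logical block
```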
@ -1,130 +0,0 @@
|
|||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
try:
|
|
||||||
import ray
|
|
||||||
except ImportError:
|
|
||||||
ray = None
|
|
||||||
|
|
||||||
from cacheflow.core.scheduler import Scheduler
|
|
||||||
from cacheflow.worker.worker import Worker
|
|
||||||
|
|
||||||
|
|
||||||
DeviceID = Tuple[int, str, int] # rank, node resource (node IP), device id
|
|
||||||
|
|
||||||
|
|
||||||
class Controller:
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
stage_id: int,
|
|
||||||
stage_devices: List[DeviceID],
|
|
||||||
world_size: int,
|
|
||||||
tensor_parallel_size: int,
|
|
||||||
pipeline_parallel_size: int,
|
|
||||||
distributed_init_method: str,
|
|
||||||
model_name: str,
|
|
||||||
dtype: str,
|
|
||||||
seed: int,
|
|
||||||
cache_dir: Optional[str],
|
|
||||||
use_dummy_weights: bool,
|
|
||||||
use_np_cache: bool,
|
|
||||||
max_num_batched_tokens: int,
|
|
||||||
max_num_sequences: int,
|
|
||||||
use_ray: bool,
|
|
||||||
) -> None:
|
|
||||||
self.stage_id = stage_id
|
|
||||||
self.stage_devices = stage_devices
|
|
||||||
self.model_name = model_name
|
|
||||||
self.use_ray = use_ray
|
|
||||||
|
|
||||||
# Which pipeline stage is this node assigned to?
|
|
||||||
self.is_first_stage = stage_id == 0
|
|
||||||
self.is_last_stage = False
|
|
||||||
|
|
||||||
self.workers: List[Worker] = []
|
|
||||||
for rank, node_resource, device_id in stage_devices:
|
|
||||||
if self.use_ray:
|
|
||||||
worker_cls = ray.remote(num_cpus=0,
|
|
||||||
num_gpus=1,
|
|
||||||
resources={node_resource: 1e-5})(Worker).remote
|
|
||||||
else:
|
|
||||||
worker_cls = Worker
|
|
||||||
worker = worker_cls(
|
|
||||||
model_name=model_name,
|
|
||||||
dtype=dtype,
|
|
||||||
seed=seed,
|
|
||||||
distributed_init_method=distributed_init_method,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
|
||||||
pipeline_parallel_size=pipeline_parallel_size,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
use_dummy_weights=use_dummy_weights,
|
|
||||||
use_np_cache=use_np_cache,
|
|
||||||
max_num_batched_tokens=max_num_batched_tokens,
|
|
||||||
max_num_sequences=max_num_sequences,
|
|
||||||
)
|
|
||||||
self.workers.append(worker)
|
|
||||||
|
|
||||||
def get_num_available_blocks(self, block_size: int, cpu_swap_space: int,
|
|
||||||
gpu_memory_utilization: float) -> List[Tuple[int, int]]:
|
|
||||||
all_worker_results = []
|
|
||||||
for worker in self.workers:
|
|
||||||
executor = worker.get_num_available_blocks
|
|
||||||
if self.use_ray:
|
|
||||||
executor = executor.remote
|
|
||||||
|
|
||||||
result = executor(
|
|
||||||
block_size,
|
|
||||||
cpu_swap_space,
|
|
||||||
gpu_memory_utilization,
|
|
||||||
)
|
|
||||||
all_worker_results.append(result)
|
|
||||||
if self.use_ray:
|
|
||||||
all_worker_results = ray.get(all_worker_results)
|
|
||||||
return all_worker_results
|
|
||||||
|
|
||||||
def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
|
|
||||||
num_cpu_blocks: int):
|
|
||||||
all_worker_futures = []
|
|
||||||
for worker in self.workers:
|
|
||||||
executor = worker.init_cache_engine
|
|
||||||
if self.use_ray:
|
|
||||||
executor = executor.remote
|
|
||||||
future = executor(
|
|
||||||
block_size,
|
|
||||||
num_gpu_blocks,
|
|
||||||
num_cpu_blocks,
|
|
||||||
)
|
|
||||||
all_worker_futures.append(future)
|
|
||||||
if self.use_ray:
|
|
||||||
ray.get(all_worker_futures)
|
|
||||||
|
|
||||||
def set_next(
|
|
||||||
self,
|
|
||||||
next_node: Union['Controller', 'Scheduler'],
|
|
||||||
) -> None:
|
|
||||||
self.next_node = next_node
|
|
||||||
self.is_last_stage = isinstance(next_node, Scheduler)
|
|
||||||
|
|
||||||
def execute_stage(self, *args, **kwargs) -> None:
|
|
||||||
all_outputs = []
|
|
||||||
for worker in self.workers:
|
|
||||||
executor = (worker.execute_stage.remote
|
|
||||||
if self.use_ray else worker.execute_stage)
|
|
||||||
output = executor(*args, **kwargs)
|
|
||||||
all_outputs.append(output)
|
|
||||||
|
|
||||||
if self.use_ray:
|
|
||||||
all_outputs = ray.get(all_outputs)
|
|
||||||
|
|
||||||
# Make sure all workers have the same results.
|
|
||||||
output = all_outputs[0]
|
|
||||||
for other_output in all_outputs[1:]:
|
|
||||||
assert output == other_output
|
|
||||||
|
|
||||||
if self.is_last_stage:
|
|
||||||
self.next_node.post_step(output)
|
|
||||||
else:
|
|
||||||
# TODO: Support pipeline parallelism.
|
|
||||||
assert False
|
|
@ -1,14 +1,13 @@
|
|||||||
"""A GPU worker class."""
|
"""A GPU worker class."""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from cacheflow.model_executor import (get_model, get_cache_block_size,
|
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||||
InputMetadata, set_random_seed)
|
SchedulerConfig)
|
||||||
|
from cacheflow.model_executor import get_model, InputMetadata, set_random_seed
|
||||||
from cacheflow.model_executor.parallel_utils.parallel_state import (
|
from cacheflow.model_executor.parallel_utils.parallel_state import (
|
||||||
initialize_model_parallel,
|
initialize_model_parallel, initialize_all_reduce_launcher)
|
||||||
initialize_all_reduce_launcher,
|
|
||||||
get_tensor_model_parallel_world_size)
|
|
||||||
from cacheflow.sampling_params import SamplingParams
|
from cacheflow.sampling_params import SamplingParams
|
||||||
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
|
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
|
||||||
SequenceOutputs)
|
SequenceOutputs)
|
||||||
@ -26,59 +25,46 @@ class Worker:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_config: ModelConfig,
|
||||||
dtype: str,
|
parallel_config: ParallelConfig,
|
||||||
seed: int,
|
scheduler_config: SchedulerConfig,
|
||||||
distributed_init_method: str,
|
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
distributed_init_method: str,
|
||||||
cache_dir: Optional[str],
|
|
||||||
use_dummy_weights: bool,
|
|
||||||
use_np_cache: bool,
|
|
||||||
max_num_batched_tokens: int,
|
|
||||||
max_num_sequences: int,
|
|
||||||
tensor_parallel_size: int = 1,
|
|
||||||
pipeline_parallel_size: int = 1,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.init_distributed_environment(distributed_init_method,
|
self.model_config = model_config
|
||||||
rank,
|
self.parallel_config = parallel_config
|
||||||
world_size,
|
self.scheduler_config = scheduler_config
|
||||||
tensor_parallel_size,
|
self.rank = rank
|
||||||
pipeline_parallel_size)
|
self.distributed_init_method = distributed_init_method
|
||||||
self.worker_id = rank
|
|
||||||
self.seed = seed
|
# Initialize the distributed environment.
|
||||||
set_random_seed(self.seed)
|
_init_distributed_environment(parallel_config, rank,
|
||||||
|
distributed_init_method)
|
||||||
|
|
||||||
# Initialize the model.
|
# Initialize the model.
|
||||||
self.model, self.dtype = get_model(
|
set_random_seed(self.model_config.seed)
|
||||||
model_name, dtype=dtype, cache_dir=cache_dir,
|
self.model = get_model(model_config)
|
||||||
use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
|
|
||||||
tensor_model_parallel_world_size = (
|
|
||||||
get_tensor_model_parallel_world_size())
|
|
||||||
self.max_num_batched_tokens = max_num_batched_tokens
|
|
||||||
initialize_all_reduce_launcher(
|
initialize_all_reduce_launcher(
|
||||||
self.max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
|
self.scheduler_config.max_num_batched_tokens,
|
||||||
self.max_num_sequences = max_num_sequences
|
self.model_config.get_hidden_size(),
|
||||||
self.num_layers = self.model.config.num_hidden_layers
|
self.model_config.dtype,
|
||||||
assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
|
)
|
||||||
self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
|
|
||||||
self.head_size = self.model.config.hidden_size // (self.num_heads * tensor_model_parallel_world_size)
|
|
||||||
|
|
||||||
# We reset the seed after initializing the model to ensure that
|
# Uninitialized cache engine. Will be initialized by
|
||||||
# the random state is not affected by the model initialization.
|
|
||||||
set_random_seed(seed)
|
|
||||||
|
|
||||||
# Uninitialized cache engine. Will be initialized with
|
|
||||||
# self.init_cache_engine().
|
# self.init_cache_engine().
|
||||||
|
self.cache_config = None
|
||||||
self.block_size = None
|
self.block_size = None
|
||||||
self.cache_engine = None
|
self.cache_engine = None
|
||||||
self.cache_events = None
|
self.cache_events = None
|
||||||
self.gpu_cache = None
|
self.gpu_cache = None
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_num_available_blocks(
|
def profile_num_available_blocks(
|
||||||
self, block_size: int, cpu_swap_space: int,
|
self,
|
||||||
gpu_memory_utilization: float) -> Tuple[int, int]:
|
block_size: int,
|
||||||
|
gpu_memory_utilization: float,
|
||||||
|
cpu_swap_space: int,
|
||||||
|
) -> Tuple[int, int]:
|
||||||
# Profile the memory usage of the model and get the maximum number of
|
# Profile the memory usage of the model and get the maximum number of
|
||||||
# cache blocks that can be allocated with the remaining free memory.
|
# cache blocks that can be allocated with the remaining free memory.
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -90,14 +76,15 @@ class Worker:
|
|||||||
# Enable top-k sampling to reflect the accurate memory usage.
|
# Enable top-k sampling to reflect the accurate memory usage.
|
||||||
sampling_params = SamplingParams(top_p=0.99,
|
sampling_params = SamplingParams(top_p=0.99,
|
||||||
top_k=self.model.config.vocab_size - 1)
|
top_k=self.model.config.vocab_size - 1)
|
||||||
|
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
|
max_num_seqs = self.scheduler_config.max_num_seqs
|
||||||
seqs = []
|
seqs = []
|
||||||
for group_id in range(self.max_num_sequences):
|
for group_id in range(max_num_seqs):
|
||||||
seq_len = (self.max_num_batched_tokens // self.max_num_sequences +
|
seq_len = (max_num_batched_tokens // max_num_seqs +
|
||||||
(group_id < self.max_num_batched_tokens %
|
(group_id < max_num_batched_tokens % max_num_seqs))
|
||||||
self.max_num_sequences))
|
|
||||||
seq_data = SequenceData([0] * seq_len)
|
seq_data = SequenceData([0] * seq_len)
|
||||||
seq = SequenceGroupMetadata(
|
seq = SequenceGroupMetadata(
|
||||||
group_id=group_id,
|
request_id=str(group_id),
|
||||||
is_prompt=True,
|
is_prompt=True,
|
||||||
seq_data={group_id: seq_data},
|
seq_data={group_id: seq_data},
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
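The profiling pass builds `max_num_seqs` dummy sequences whose lengths split `max_num_batched_tokens` as evenly as possible. A quick check of that split with the default values from `add_server_arguments`:

```python
# Quick check of the sequence-length split used when profiling memory.
max_num_batched_tokens = 2560   # default --max-num-batched-tokens
max_num_seqs = 256              # default --max-num-seqs

seq_lens = [
    max_num_batched_tokens // max_num_seqs
    + (group_id < max_num_batched_tokens % max_num_seqs)
    for group_id in range(max_num_seqs)
]
assert sum(seq_lens) == max_num_batched_tokens  # every profiled token accounted for
print(seq_lens[0], seq_lens[-1])                # 10 10 with the default settings
```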
@ -105,13 +92,14 @@ class Worker:
|
|||||||
)
|
)
|
||||||
seqs.append(seq)
|
seqs.append(seq)
|
||||||
|
|
||||||
input_tokens, input_positions, input_metadata = self.prepare_inputs(seqs)
|
input_tokens, input_positions, input_metadata = self._prepare_inputs(seqs)
|
||||||
|
|
||||||
# Execute the model.
|
# Execute the model.
|
||||||
|
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||||
self.model(
|
self.model(
|
||||||
input_ids=input_tokens,
|
input_ids=input_tokens,
|
||||||
positions=input_positions,
|
positions=input_positions,
|
||||||
kv_caches=[(None, None)] * self.num_layers,
|
kv_caches=[(None, None)] * num_layers,
|
||||||
input_metadata=input_metadata,
|
input_metadata=input_metadata,
|
||||||
cache_events=None,
|
cache_events=None,
|
||||||
)
|
)
|
||||||
@ -121,53 +109,27 @@ class Worker:
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
peak_memory = torch.cuda.max_memory_allocated()
|
peak_memory = torch.cuda.max_memory_allocated()
|
||||||
total_gpu_memory = get_gpu_memory()
|
total_gpu_memory = get_gpu_memory()
|
||||||
cache_block_size = get_cache_block_size(block_size, self.num_heads,
|
cache_block_size = CacheEngine.get_cache_block_size(
|
||||||
self.head_size, self.num_layers,
|
block_size, self.model_config, self.parallel_config)
|
||||||
self.dtype)
|
|
||||||
num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
|
num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
|
||||||
- peak_memory) // cache_block_size)
|
- peak_memory) // cache_block_size)
|
||||||
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
|
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
# Reset the seed to ensure that the model output is not affected by
|
|
||||||
# the profiling.
|
# Reset the seed to ensure that the random state is not affected by
|
||||||
set_random_seed(self.seed)
|
# the model initialization and profiling.
|
||||||
|
set_random_seed(self.model_config.seed)
|
||||||
return num_gpu_blocks, num_cpu_blocks
|
return num_gpu_blocks, num_cpu_blocks
|
||||||
|
|
||||||
def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
|
def init_cache_engine(self, cache_config: CacheConfig) -> None:
|
||||||
num_cpu_blocks: int):
|
self.cache_config = cache_config
|
||||||
self.block_size = block_size
|
self.block_size = cache_config.block_size
|
||||||
self.cache_engine = CacheEngine(
|
self.cache_engine = CacheEngine(
|
||||||
worker_id=self.worker_id,
|
self.cache_config, self.model_config, self.parallel_config)
|
||||||
num_layers=self.num_layers,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
head_size=self.head_size,
|
|
||||||
block_size=self.block_size,
|
|
||||||
num_gpu_blocks=num_gpu_blocks,
|
|
||||||
num_cpu_blocks=num_cpu_blocks,
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
self.cache_events = self.cache_engine.events
|
self.cache_events = self.cache_engine.events
|
||||||
self.gpu_cache = self.cache_engine.gpu_cache
|
self.gpu_cache = self.cache_engine.gpu_cache
|
||||||
|
|
||||||
def init_distributed_environment(self,
|
def _prepare_inputs(
|
||||||
distributed_init_method: str,
|
|
||||||
rank: int,
|
|
||||||
world_size: int,
|
|
||||||
tensor_parallel_size: int = 1,
|
|
||||||
pipeline_parallel_size: int = 1) -> None:
|
|
||||||
"""Initialize the distributed environment."""
|
|
||||||
torch.distributed.init_process_group(
|
|
||||||
backend='nccl',
|
|
||||||
init_method=distributed_init_method,
|
|
||||||
world_size=world_size,
|
|
||||||
rank=rank,
|
|
||||||
)
|
|
||||||
# A small all_reduce for warmup.
|
|
||||||
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
|
||||||
initialize_model_parallel(tensor_parallel_size,
|
|
||||||
pipeline_parallel_size)
|
|
||||||
|
|
||||||
def prepare_inputs(
|
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
|
) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
|
||||||
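The profiled peak memory is then turned into GPU and CPU block counts. A back-of-the-envelope version of that computation; every byte count below is a made-up example, and the block size is taken from the earlier `CacheEngine` sketch:

```python
# Back-of-the-envelope version of the block-count computation in
# profile_num_available_blocks. All byte counts are made-up examples.
total_gpu_memory = 16 * 1024**3   # 16 GiB card (assumed)
gpu_memory_utilization = 0.95
peak_memory = 4 * 1024**3         # profiled weight + activation peak (assumed)
cpu_swap_space = 4 * 1024**3      # --swap-space 4 (GiB)
cache_block_size = 589824         # bytes per block, from the CacheEngine example

num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
                      - peak_memory) // cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
print(num_gpu_blocks, num_cpu_blocks)  # 20388 7281 with these numbers
```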
@ -284,7 +246,7 @@ class Worker:
|
|||||||
return tokens_tensor, positions_tensor, input_metadata
|
return tokens_tensor, positions_tensor, input_metadata
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_stage(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
blocks_to_swap_in: Dict[int, int],
|
blocks_to_swap_in: Dict[int, int],
|
||||||
@ -316,7 +278,7 @@ class Worker:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
# Prepare input tensors.
|
# Prepare input tensors.
|
||||||
input_tokens, input_positions, input_metadata = self.prepare_inputs(
|
input_tokens, input_positions, input_metadata = self._prepare_inputs(
|
||||||
seq_group_metadata_list)
|
seq_group_metadata_list)
|
||||||
|
|
||||||
# Execute the model.
|
# Execute the model.
|
||||||
@ -330,6 +292,24 @@ class Worker:
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _init_distributed_environment(
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
rank: int,
|
||||||
|
distributed_init_method: str,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the distributed environment."""
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend="nccl",
|
||||||
|
world_size=parallel_config.world_size,
|
||||||
|
rank=rank,
|
||||||
|
init_method=distributed_init_method,
|
||||||
|
)
|
||||||
|
# A small all_reduce for warmup.
|
||||||
|
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
||||||
|
initialize_model_parallel(parallel_config.tensor_parallel_size,
|
||||||
|
parallel_config.pipeline_parallel_size)
|
||||||
|
|
||||||
|
|
||||||
def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
|
def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
|
||||||
return x + [0] * ((-len(x)) % multiple_of)
|
return x + [0] * ((-len(x)) % multiple_of)
|
||||||
|
|
||||||
|
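`_pad_to_alignment` pads a token list with zeros up to the next multiple of `multiple_of`. A tiny demonstration, re-implemented locally because the helper is private to `cacheflow/worker/worker.py`:

```python
# Behaviour of the _pad_to_alignment helper, re-implemented for illustration.
from typing import List

def pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
    # (-len(x)) % multiple_of is the gap to the next multiple, or 0 if already aligned.
    return x + [0] * ((-len(x)) % multiple_of)

assert pad_to_alignment([1, 2, 3], multiple_of=8) == [1, 2, 3, 0, 0, 0, 0, 0]
assert pad_to_alignment([1, 2, 3, 4], multiple_of=4) == [1, 2, 3, 4]  # already aligned
```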
44 examples/simple_server.py Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import argparse
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from cacheflow import (add_server_arguments, initialize_server_from_args,
|
||||||
|
SamplingParams)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: argparse.Namespace):
|
||||||
|
# Initialize the server.
|
||||||
|
server = initialize_server_from_args(args)
|
||||||
|
|
||||||
|
# Test the following prompts.
|
||||||
|
test_prompts = [
|
||||||
|
("A robot may not injure a human being", SamplingParams()),
|
||||||
|
("To be or not to be,",
|
||||||
|
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
|
||||||
|
("What is the meaning of life?",
|
||||||
|
SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
|
||||||
|
("It is only with the heart that one can see rightly",
|
||||||
|
SamplingParams(n=3, use_beam_search=True, temperature=0.0)),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Run the server.
|
||||||
|
while True:
|
||||||
|
# To test iteration-level scheduling, we add one request at each step.
|
||||||
|
if test_prompts:
|
||||||
|
prompt, sampling_params = test_prompts.pop(0)
|
||||||
|
request_id = str(uuid.uuid4().hex[:8])
|
||||||
|
server.add_request(request_id, prompt, sampling_params)
|
||||||
|
|
||||||
|
request_outputs = server.step()
|
||||||
|
for request_output in request_outputs:
|
||||||
|
if request_output.done:
|
||||||
|
print(request_output)
|
||||||
|
|
||||||
|
if not (server.has_unfinished_requests() or test_prompts):
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
|
||||||
|
parser = add_server_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
@ -1,38 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
from cacheflow.core.server import (
|
|
||||||
add_server_arguments, process_server_arguments,
|
|
||||||
init_local_server_and_frontend_with_arguments)
|
|
||||||
from cacheflow.sampling_params import SamplingParams
|
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
|
||||||
server, frontend = init_local_server_and_frontend_with_arguments(args)
|
|
||||||
# Test the following inputs.
|
|
||||||
test_inputs = [
|
|
||||||
("A robot may not injure a human being", {}), # Use default parameters.
|
|
||||||
("To be or not to be,", {"temperature": 0.8, "top_k": 5, "presence_penalty": 0.2}),
|
|
||||||
("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95, "frequency_penalty": 0.1}),
|
|
||||||
("It is only with the heart that one can see rightly", {"n": 3, "use_beam_search": True, "temperature": 0.0}),
|
|
||||||
]
|
|
||||||
while True:
|
|
||||||
if test_inputs:
|
|
||||||
text, sampling_params_dict = test_inputs.pop(0)
|
|
||||||
sampling_params = SamplingParams(**sampling_params_dict)
|
|
||||||
sampling_params = frontend.add_eos_token(sampling_params)
|
|
||||||
frontend.query(text, sampling_params)
|
|
||||||
server.add_sequence_groups(frontend.get_inputs())
|
|
||||||
updated_seq_groups = server.step()
|
|
||||||
for seq_group in updated_seq_groups:
|
|
||||||
if seq_group.is_finished():
|
|
||||||
frontend.print_response(seq_group)
|
|
||||||
if not (server.has_unfinished_requests() or test_inputs):
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
|
|
||||||
parser = add_server_arguments(parser)
|
|
||||||
args = parser.parse_args()
|
|
||||||
args = process_server_arguments(args)
|
|
||||||
main(args)
|
|