Use runtime profiling to replace manual memory analyzers (#81)

Zhuohan Li 2023-05-19 11:35:44 -06:00 committed by GitHub
parent 825d8892b5
commit f756799b84
14 changed files with 211 additions and 478 deletions

View File

@@ -6,15 +6,14 @@ try:
import ray
except ImportError:
ray = None
import numpy as np
import torch
from cacheflow.core.scheduler import Scheduler
from cacheflow.frontend.simple_frontend import SimpleFrontend
from cacheflow.logger import init_logger
from cacheflow.model_executor import get_memory_analyzer
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroup
from cacheflow.utils import get_gpu_memory, get_cpu_memory
from cacheflow.worker.controller import Controller, DeviceID
logger = init_logger(__name__)
@@ -34,14 +33,13 @@ class Server:
dtype: str,
seed: int,
swap_space: int,
gpu_memory_utilization: float,
max_num_batched_tokens: int,
max_num_sequences: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
gpu_memory: int,
cpu_memory: int,
use_ray: bool,
log_stats: bool,
):
@@ -63,19 +61,6 @@ class Server:
assert self.world_size == 1, (
"Only support single GPU without Ray.")
self.memory_analyzer = get_memory_analyzer(
model_name=model,
block_size=block_size,
dtype=dtype,
gpu_memory=gpu_memory,
cpu_memory=cpu_memory,
tensor_parallel_size=tensor_parallel_size,
)
self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
max_num_batched_tokens=max_num_batched_tokens)
self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
swap_space_gib=swap_space)
# Create a controller for each pipeline stage.
self.controllers: List[Controller] = []
for i in range(pipeline_parallel_size):
@@ -87,19 +72,35 @@
tensor_parallel_size=tensor_parallel_size,
distributed_init_method=distributed_init_method,
model_name=model,
block_size=block_size,
num_gpu_blocks=self.num_gpu_blocks,
num_cpu_blocks=self.num_cpu_blocks,
dtype=dtype,
seed=seed,
cache_dir=cache_dir,
use_dummy_weights=use_dummy_weights,
use_np_cache=use_np_cache,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
use_ray=use_ray,
)
self.controllers.append(controller)
# Initialize cache engine.
all_worker_num_available_blocks = []
for controller in self.controllers:
all_worker_num_available_blocks.extend(
controller.get_num_available_blocks(
block_size, swap_space, gpu_memory_utilization)
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
self.num_gpu_blocks = np.min([b[0] for b in all_worker_num_available_blocks])
self.num_cpu_blocks = np.min([b[1] for b in all_worker_num_available_blocks])
logger.info(f'# GPU blocks: {self.num_gpu_blocks}, '
f'# CPU blocks: {self.num_cpu_blocks}')
for controller in self.controllers:
controller.init_cache_engine(block_size, self.num_gpu_blocks,
self.num_cpu_blocks)
# Create a scheduler.
self.scheduler = Scheduler(
controllers=self.controllers,
@@ -214,7 +215,11 @@ def initialize_cluster(
all_stage_devices)
_GiB = 1 << 30
def add_server_arguments(parser: argparse.ArgumentParser):
"""Shared arguments for CacheFlow servers."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--cache-dir', type=str, default=None,
@@ -238,6 +243,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--log-stats', action='store_true', help='log system statistics')
@@ -245,8 +251,11 @@
def process_server_arguments(args: argparse.Namespace):
"""Post process the parsed arguments."""
if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
args.use_ray = True
args.swap_space = args.swap_space * _GiB
args.max_num_sequences = min(args.max_num_sequences, args.max_num_batched_tokens)
return args
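
Side note on the post-processing above: a minimal, self-contained sketch of what process_server_arguments does to a parsed namespace. The values mirror the defaults from add_server_arguments, except tensor_parallel_size=2 so the Ray switch is visible; the namespace is hand-built here purely for illustration.

import argparse

_GiB = 1 << 30
args = argparse.Namespace(pipeline_parallel_size=1, tensor_parallel_size=2,
                          use_ray=False, swap_space=20,
                          max_num_batched_tokens=2560, max_num_sequences=256)
if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
    args.use_ray = True                        # any multi-GPU setup goes through Ray
args.swap_space = args.swap_space * _GiB       # GiB on the CLI, bytes internally
args.max_num_sequences = min(args.max_num_sequences, args.max_num_batched_tokens)
assert args.use_ray and args.swap_space == 20 * _GiB
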
@@ -274,14 +283,13 @@ def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
gpu_memory_utilization=args.gpu_memory_utilization,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
log_stats=args.log_stats,
)
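
The substantive change in this file is that the server no longer derives block counts from an analytic memory model: it asks every controller for its profiled (GPU, CPU) block counts and keeps the per-dimension minimum, so the shared scheduler only issues operations valid on every worker. A standalone sketch of that aggregation step, with an illustrative helper name and plain min in place of np.min:

from typing import List, Tuple

def aggregate_block_counts(
        worker_results: List[Tuple[int, int]]) -> Tuple[int, int]:
    # Each worker reports (num_gpu_blocks, num_cpu_blocks) from its own
    # profiling run; the centralized scheduler must stay within the
    # smallest budget along each dimension.
    num_gpu_blocks = min(r[0] for r in worker_results)
    num_cpu_blocks = min(r[1] for r in worker_results)
    return num_gpu_blocks, num_cpu_blocks

assert aggregate_block_counts([(1000, 400), (950, 420)]) == (950, 400)
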

View File

@@ -15,7 +15,7 @@ from cacheflow.core.server import (Server, add_server_arguments,
from cacheflow.frontend.utils import get_tokenizer
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
from cacheflow.utils import Counter
from cacheflow.worker.controller import DeviceID
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
@@ -34,6 +34,7 @@ class FastAPIServer:
dtype: str,
seed: int,
swap_space: int,
gpu_memory_utilization: float,
max_num_batched_tokens: int,
max_num_sequences: int,
num_nodes: int,
@@ -41,6 +42,7 @@ class FastAPIServer:
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
server_use_ray: bool,
log_stats: bool,
):
self.block_size = block_size
@@ -62,15 +64,15 @@
dtype=dtype,
seed=seed,
swap_space=swap_space,
gpu_memory_utilization=gpu_memory_utilization,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=server_use_ray,
log_stats=log_stats,
)
self.running_seq_groups: Dict[int, SequenceGroup] = {}
@@ -182,6 +184,7 @@ if __name__ == "__main__":
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
gpu_memory_utilization=args.gpu_memory_utilization,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
@@ -189,6 +192,7 @@
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
server_use_ray=args.use_ray,
log_stats=args.log_stats,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")

View File

@@ -1,11 +1,12 @@
from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.model_loader import get_model, get_memory_analyzer
from cacheflow.model_executor.utils import set_random_seed
from cacheflow.model_executor.model_loader import get_model
from cacheflow.model_executor.utils import (set_random_seed,
get_cache_block_size)
__all__ = [
"InputMetadata",
"get_cache_block_size",
"get_model",
"get_memory_analyzer",
"set_random_seed",
]
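
Downstream imports shift accordingly: get_memory_analyzer disappears and get_cache_block_size is exported in its place. A usage sketch, assuming the cacheflow package from this commit is importable:

from cacheflow.model_executor import (InputMetadata, get_cache_block_size,
                                      get_model, set_random_seed)
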

View File

@@ -11,6 +11,8 @@ from cacheflow import pos_encoding_ops
from cacheflow.model_executor.input_metadata import InputMetadata
_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
class GPTCacheFlowAttention(nn.Module):
"""GPT-style multi-head attention.
@@ -39,11 +41,19 @@ class GPTCacheFlowAttention(nn.Module):
5. Output a flattened 1D tensor.
"""
def __init__(self, scale: float) -> None:
def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.attn_op = xops.fmha.cutlass.FwOp()
if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f'head_size ({self.head_size}) is not supported by '
'the single_query_cached_kv_attention kernel. '
'Use one of the following head sizes: '
f'{_SUPPORTED_HEAD_SIZES}.')
def multi_query_kv_attention(
self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
@@ -74,14 +84,6 @@ class GPTCacheFlowAttention(nn.Module):
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
) -> None:
head_size = value_cache.shape[2]
supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
if head_size not in supported_head_sizes:
raise ValueError(f'head_size ({head_size}) is not supported by '
'the single_query_cached_kv_attention kernel. '
'Use one of the following head sizes: '
f'{supported_head_sizes}.')
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
@@ -100,8 +102,8 @@ class GPTCacheFlowAttention(nn.Module):
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
key_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
@@ -109,11 +111,9 @@ class GPTCacheFlowAttention(nn.Module):
# tensor of shape [num_tokens, 3 * num_heads * head_size].
# Reshape the query, key, and value tensors.
num_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_heads, head_size)
value = value.view(-1, num_heads, head_size)
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
# Pre-allocate the output tensor.
output = torch.empty_like(query)
@@ -134,8 +134,11 @@ class GPTCacheFlowAttention(nn.Module):
cache_event.wait()
# Reshape the keys and values and store them in the cache.
# When key_cache and value_cache are not provided, the new key
# and value vectors will not be cached.
num_valid_tokens = input_metadata.num_valid_tokens
if num_valid_tokens > 0:
if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None):
# The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],
@ -146,6 +149,10 @@ class GPTCacheFlowAttention(nn.Module):
)
if input_metadata.num_generation_tokens > 0:
assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when "
"generating tokens."
)
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens],
@ -156,7 +163,7 @@ class GPTCacheFlowAttention(nn.Module):
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
return output.view(-1, num_heads * head_size)
return output.view(-1, self.num_heads * self.head_size)
class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
@@ -164,12 +171,14 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
rotary_dim: int,
max_position: int = 8192,
base: int = 10000,
) -> None:
super().__init__(scale)
super().__init__(num_heads, head_size, scale)
# Create the cos and sin cache.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
@@ -199,12 +208,11 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# Apply rotary embedding to the query and key before passing them
# to the attention op.
head_size = value_cache.shape[2]
pos_encoding_ops.rotary_embedding_neox(
positions,
query,
key,
head_size,
self.head_size,
self.cos_sin_cache,
)
return super().forward(
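
Two behavioral points in this file are easy to miss in the diff: the head-size check moves from the decode path into the constructor, and the key/value caches become Optional so the model can run before any cache exists, which is exactly what the worker's profiling pass does with kv_caches = [(None, None)] * num_layers. A standalone sketch of the constructor-time check; the function name is illustrative, the real check lives in GPTCacheFlowAttention.__init__:

_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]

def check_head_size(head_size: int) -> None:
    # Rejects unsupported models at construction time instead of failing on
    # the first single_query_cached_kv_attention call.
    if head_size not in _SUPPORTED_HEAD_SIZES:
        raise ValueError(
            f'head_size ({head_size}) is not supported by the '
            f'single_query_cached_kv_attention kernel. '
            f'Use one of the following head sizes: {_SUPPORTED_HEAD_SIZES}.')

check_head_size(128)     # supported
# check_head_size(100)   # would raise ValueError

When both caches are None, forward() skips reshape_and_cache and asserts that no generation tokens are present, since decoding requires a cache.
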

View File

@@ -74,7 +74,7 @@ class Sampler(nn.Module):
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == probs.shape[0]
if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
if any(p < 1.0 for p in top_ps) or any(k != self.vocab_size for k in top_ks):
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
# Sample the next tokens.
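
The sentinel for "no top-k" changes here from -1 to the vocabulary size: a k equal to the vocab size keeps every token, so truncation can be skipped exactly when all ks equal it. A hedged sketch of the truncation this guard short-circuits (illustrative helper, not the real _apply_top_p_top_k; ties at the threshold may keep a few extra tokens):

import torch

def apply_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    # Zero out everything below the k-th largest probability per row, then
    # renormalize. k == vocab_size leaves the distribution unchanged.
    if k >= probs.shape[-1]:
        return probs
    threshold = probs.topk(k, dim=-1).values[..., -1:]
    probs = probs.masked_fill(probs < threshold, 0.0)
    return probs / probs.sum(dim=-1, keepdim=True)

probs = torch.softmax(torch.randn(2, 8), dim=-1)
assert torch.allclose(apply_top_k(probs, 8), probs)   # vocab-sized k is a no-op

The worker's profiling pass below deliberately uses top_k = vocab_size - 1 so that this truncation path, and the memory it needs, are exercised during the dry run.
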

View File

@@ -1,370 +0,0 @@
import torch
from transformers import AutoConfig
from cacheflow.logger import init_logger
from cacheflow.model_executor.utils import get_dtype_size
logger = init_logger(__name__)
_GiB = 1 << 30
class CacheFlowMemoryAnalyzer:
def get_max_num_gpu_blocks(
self,
max_num_batched_tokens: int,
memory_utilization: float,
) -> int:
raise NotImplementedError()
def get_workspace_size(self) -> int:
return 1 * _GiB
def get_cache_block_size(self) -> int:
raise NotImplementedError()
def get_max_num_cpu_blocks(
self,
swap_space_gib: int,
) -> int:
swap_space = swap_space_gib * _GiB
cpu_memory = self.cpu_memory
if swap_space > 0.8 * cpu_memory:
raise ValueError(f'The swap space ({swap_space_gib:.2f} GiB) '
'takes more than 80% of the available memory '
f'({cpu_memory / _GiB:.2f} GiB).'
'Please check the swap space size.')
if swap_space > 0.5 * cpu_memory:
logger.info(f'WARNING: The swap space ({swap_space_gib:.2f} GiB) '
'takes more than 50% of the available memory '
f'({cpu_memory / _GiB:.2f} GiB).'
'This may slow the system performance.')
max_num_blocks = swap_space // self.get_cache_block_size()
return max_num_blocks
def get_param_size(self) -> int:
raise NotImplementedError()
def get_max_act_size(self, max_num_batched_tokens: int) -> int:
raise NotImplementedError()
def get_cache_block_size(self) -> int:
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
value_cache_block = key_cache_block
total = self.num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def get_max_num_gpu_blocks(
self,
max_num_batched_tokens: int,
memory_utilization: float = 0.95,
) -> int:
# NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
usable_memory = int(memory_utilization * self.gpu_memory)
param_size = self.get_param_size()
act_size = self.get_max_act_size(max_num_batched_tokens)
workspace_size = self.get_workspace_size()
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
if max_cache_size <= 0:
raise RuntimeError('Not enough GPU memory.')
max_num_blocks = max_cache_size // self.get_cache_block_size()
return max_num_blocks
class GPT2MemoryAnalyzer(CacheFlowMemoryAnalyzer):
def __init__(
self,
model_name: str,
block_size: int,
dtype: torch.dtype,
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int,
) -> None:
self.model_name = model_name
self.block_size = block_size
self.dtype = dtype
self.gpu_memory = gpu_memory
self.cpu_memory = cpu_memory
self.tensor_parallel_size = tensor_parallel_size
config = AutoConfig.from_pretrained(model_name)
self.num_layers = config.num_hidden_layers
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = config.hidden_size // self.num_heads
self.ffn_size = config.n_inner if config.n_inner is not None else 4 * self.hidden_size
self.vocab_size = config.vocab_size
self.max_position = config.max_position_embeddings
def get_param_size(self) -> int:
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
position_embedding = self.max_position * self.hidden_size
ln1 = 2 * self.hidden_size
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
mha = ln1 + q + k + v + out
ln2 = 2 * self.hidden_size
ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
ffn = ln2 + ffn1 + ffn2
total = (word_embedding + position_embedding +
self.num_layers * (mha + ffn))
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def get_max_act_size(
self,
max_num_batched_tokens: int,
) -> int:
# NOTE: We approximately calculate the maximum activation size by
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
# Double the activation size for input and output.
max_act = 2 * (max(qkv, ffn) + residual)
# Size of output logits.
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
max_act = max(max_act, output_logits)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act
class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
def __init__(
self,
model_name: str,
block_size: int,
dtype: torch.dtype,
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int,
) -> None:
self.model_name = model_name
self.block_size = block_size
self.dtype = dtype
self.gpu_memory = gpu_memory
self.cpu_memory = cpu_memory
self.tensor_parallel_size = tensor_parallel_size
config = AutoConfig.from_pretrained(model_name)
self.num_layers = config.num_hidden_layers
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = config.hidden_size // self.num_heads
self.ffn_size = config.ffn_dim
self.embedding_size = config.word_embed_proj_dim
self.vocab_size = config.vocab_size
self.max_position = config.max_position_embeddings
def get_param_size(self) -> int:
word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
if self.embedding_size != self.hidden_size:
# Project in/out.
word_embedding += 2 * self.embedding_size * self.hidden_size
position_embedding = self.max_position * self.hidden_size
ln1 = 2 * self.hidden_size
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
mha = ln1 + q + k + v + out
ln2 = 2 * self.hidden_size
ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
ffn = ln2 + ffn1 + ffn2
total = (word_embedding + position_embedding +
self.num_layers * (mha + ffn))
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def get_max_act_size(
self,
max_num_batched_tokens: int,
) -> int:
# NOTE: We approximately calculate the maximum activation size by
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
# Double the activation size for input and output.
max_act = 2 * (max(qkv, ffn) + residual)
# Size of output logits.
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
max_act = max(max_act, output_logits)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act
class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
def __init__(
self,
model_name: str,
block_size: int,
dtype: torch.dtype,
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int,
) -> None:
self.model_name = model_name
self.block_size = block_size
self.dtype = dtype
self.gpu_memory = gpu_memory
self.cpu_memory = cpu_memory
self.tensor_parallel_size = tensor_parallel_size
config = AutoConfig.from_pretrained(model_name)
self.num_layers = config.num_hidden_layers
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = config.hidden_size // self.num_heads
self.ffn_size = config.intermediate_size
self.vocab_size = config.vocab_size
self.max_position = 8192
def get_param_size(self) -> int:
# NOTE: LLaMA does not tie the two embeddings.
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
# NOTE: LLaMA does not have bias terms.
ln1 = self.hidden_size
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size
# Rotary embedding.
# TODO(woosuk): Share the rotary embedding between layers.
rot = self.max_position * self.head_size
mha = ln1 + q + k + v + out + rot
ln2 = self.hidden_size
gate = self.hidden_size * self.ffn_size // self.tensor_parallel_size
down = self.ffn_size * self.hidden_size // self.tensor_parallel_size
up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
ffn = ln2 + gate + down + up
total = word_embedding + self.num_layers * (mha + ffn) + lm_head
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def get_max_act_size(
self,
max_num_batched_tokens: int,
) -> int:
# NOTE: We approximately calculate the maximum activation size by
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
# Double the activation size for input and output.
max_act = 2 * (max(qkv, ffn) + residual)
# Size of output logits.
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
max_act = max(max_act, output_logits)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act
class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
def __init__(
self,
model_name: str,
block_size: int,
dtype: torch.dtype,
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int,
) -> None:
self.model_name = model_name
self.block_size = block_size
self.dtype = dtype
self.gpu_memory = gpu_memory
self.cpu_memory = cpu_memory
self.tensor_parallel_size = tensor_parallel_size
config = AutoConfig.from_pretrained(model_name)
self.num_layers = config.num_hidden_layers
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = config.hidden_size // self.num_heads
self.ffn_size = config.intermediate_size
self.vocab_size = config.vocab_size
self.max_position = 8192
self.tie_word_embeddings = config.tie_word_embeddings
def get_param_size(self) -> int:
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
if self.tie_word_embeddings:
lm_head = 0
else:
lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
ln1 = 2 * self.hidden_size
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
# Rotary embedding.
# TODO(woosuk): Share the rotary embedding between layers.
rot = self.max_position * self.head_size
mha = ln1 + q + k + v + out + rot
ln2 = 2 * self.hidden_size
ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
ffn = ln2 + ffn1 + ffn2
total = word_embedding + self.num_layers * (mha + ffn) + lm_head
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def get_max_act_size(
self,
max_num_batched_tokens: int,
) -> int:
# NOTE: We approximately calculate the maximum activation size by
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
# Double the activation size for input and output.
max_act = 2 * (max(qkv, ffn) + residual)
# Size of output logits.
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
max_act = max(max_act, output_logits)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act
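
For contrast with the runtime profiling that replaces this file, the removed analyzers all reduced to the same analytic budget: estimate parameter, activation, and workspace sizes per architecture, subtract them from the usable GPU memory, and divide by the cache block size. A condensed, illustrative form of that calculation (same structure as the removed get_max_num_gpu_blocks):

def analytic_num_gpu_blocks(gpu_memory: int, memory_utilization: float,
                            param_size: int, max_act_size: int,
                            workspace_size: int, cache_block_size: int) -> int:
    # Every term here had to be hand-derived per model family (GPT-2, OPT,
    # LLaMA, GPT-NeoX), which is exactly what this commit stops doing.
    usable_memory = int(memory_utilization * gpu_memory)
    max_cache_size = usable_memory - (param_size + max_act_size + workspace_size)
    if max_cache_size <= 0:
        raise RuntimeError('Not enough GPU memory.')
    return max_cache_size // cache_block_size
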

View File

@@ -5,9 +5,6 @@ import torch
import torch.nn as nn
from transformers import AutoConfig, PretrainedConfig
from cacheflow.model_executor.memory_analyzer import (
CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,
LlamaMemoryAnalyzer, OPTMemoryAnalyzer)
from cacheflow.model_executor.models import (
GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
from cacheflow.model_executor.utils import get_torch_dtype
@@ -22,14 +19,6 @@ _MODEL_REGISTRY = {
"OPTForCausalLM": OPTForCausalLM,
}
_MEMORY_ANALYZERS = {
"GPT2LMHeadModel": GPT2MemoryAnalyzer,
"GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
"LlamaForCausalLM": LlamaMemoryAnalyzer,
"OPTForCausalLM": OPTMemoryAnalyzer,
}
def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
architectures = getattr(config, "architectures", [])
for arch in architectures:
@@ -41,17 +30,6 @@ def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
)
def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _MEMORY_ANALYZERS:
return _MEMORY_ANALYZERS[arch]
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
)
def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
@@ -100,18 +78,3 @@ def get_model(
model = model.cuda()
return model.eval(), torch_dtype
def get_memory_analyzer(
model_name: str,
block_size: int,
dtype: str,
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
config = AutoConfig.from_pretrained(model_name)
torch_dtype = _get_dtype(config, dtype)
memory_analyzer = _get_memory_analyzer(config)
return memory_analyzer(
model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
tensor_parallel_size)

View File

@@ -58,7 +58,8 @@ class GPT2Attention(nn.Module):
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True,
input_is_parallel=True,
perform_initialization=False)
self.attn = GPTCacheFlowAttention(scale=self.scale)
self.attn = GPTCacheFlowAttention(self.num_heads, self.head_dim,
scale=self.scale)
def forward(
self,

View File

@@ -62,7 +62,8 @@ class GPTNeoXAttention(nn.Module):
scaling = self.head_size ** -0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
self.attn = GPTNeoXCacheFlowAttention(scaling, rotary_dim)
self.attn = GPTNeoXCacheFlowAttention(self.num_heads, self.head_size,
scaling, rotary_dim)
def forward(
self,

View File

@@ -104,7 +104,8 @@ class LlamaAttention(nn.Module):
input_is_parallel=True,
perform_initialization=False,
)
self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)
self.attn = GPTNeoXCacheFlowAttention(self.num_heads, self.head_dim,
self.scaling, rotary_dim=self.head_dim)
def forward(
self,

View File

@@ -74,7 +74,8 @@ class OPTAttention(nn.Module):
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.attn = GPTCacheFlowAttention(scale=self.scaling)
self.attn = GPTCacheFlowAttention(self.num_heads, self.head_dim,
scale=self.scaling)
def forward(
self,

View File

@@ -40,3 +40,15 @@ def set_random_seed(seed: int) -> None:
if model_parallel_is_initialized():
model_parallel_cuda_manual_seed(seed)
def get_cache_block_size(block_size: int,
num_heads: int,
head_size: int,
num_layers: int,
dtype: str) -> int:
key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(dtype)
return dtype_size * total
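
A worked example of the new helper with hypothetical but representative shapes (fp16, so get_dtype_size returns 2 bytes):

# block_size=16, num_heads=40, head_size=128, num_layers=40, fp16
key_cache_block = 16 * 40 * 128              # 81,920 elements per layer per block
total_elements = 40 * 2 * key_cache_block    # key + value across all layers
cache_block_size = 2 * total_elements        # 13,107,200 bytes, about 12.5 MiB per block
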

View File

@@ -23,23 +23,18 @@ class Controller:
pipeline_parallel_size: int,
distributed_init_method: str,
model_name: str,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
dtype: str,
seed: int,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
max_num_batched_tokens: int,
max_num_sequences: int,
use_ray: bool,
) -> None:
self.stage_id = stage_id
self.stage_devices = stage_devices
self.model_name = model_name
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
self.use_ray = use_ray
# Which pipeline stage is this node assigned to?
@@ -56,9 +51,6 @@ class Controller:
worker_cls = Worker
worker = worker_cls(
model_name=model_name,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
dtype=dtype,
seed=seed,
distributed_init_method=distributed_init_method,
@@ -70,9 +62,44 @@ class Controller:
use_dummy_weights=use_dummy_weights,
use_np_cache=use_np_cache,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
)
self.workers.append(worker)
def get_num_available_blocks(self, block_size: int, cpu_swap_space: int,
gpu_memory_utilization: float) -> List[Tuple[int, int]]:
all_worker_results = []
for worker in self.workers:
executor = worker.get_num_available_blocks
if self.use_ray:
executor = executor.remote
result = executor(
block_size,
cpu_swap_space,
gpu_memory_utilization,
)
all_worker_results.append(result)
if self.use_ray:
all_worker_results = ray.get(all_worker_results)
return all_worker_results
def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
num_cpu_blocks: int):
all_worker_futures = []
for worker in self.workers:
executor = worker.init_cache_engine
if self.use_ray:
executor = executor.remote
future = executor(
block_size,
num_gpu_blocks,
num_cpu_blocks,
)
all_worker_futures.append(future)
if self.use_ray:
ray.get(all_worker_futures)
def set_next(
self,
next_node: Union['Controller', 'Scheduler'],
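
Both new controller methods use the same dispatch idiom: call the worker method directly in single-process mode, or its .remote handle plus a single ray.get when Ray is in use. A generic, hedged sketch of that pattern (the helper name and signature are illustrative):

def call_all_workers(workers, method_name: str, use_ray: bool, *args):
    results = []
    for worker in workers:
        executor = getattr(worker, method_name)
        if use_ray:
            executor = executor.remote   # returns an ObjectRef, not a value
        results.append(executor(*args))
    if use_ray:
        import ray                       # imported lazily; Ray is optional
        results = ray.get(results)       # block until every worker finishes
    return results
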

View File

@@ -3,7 +3,8 @@ from typing import Dict, List, Optional, Tuple
import torch
from cacheflow.model_executor import get_model, InputMetadata, set_random_seed
from cacheflow.model_executor import (get_model, get_cache_block_size,
InputMetadata, set_random_seed)
from cacheflow.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel,
initialize_all_reduce_launcher,
@@ -12,6 +13,7 @@ from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
SequenceOutputs)
from cacheflow.worker.cache_engine import CacheEngine
from cacheflow.utils import get_gpu_memory
class Worker:
@@ -25,9 +27,6 @@ class Worker:
def __init__(
self,
model_name: str,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
dtype: str,
seed: int,
distributed_init_method: str,
@@ -37,6 +36,7 @@ class Worker:
use_dummy_weights: bool,
use_np_cache: bool,
max_num_batched_tokens: int,
max_num_sequences: int,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
) -> None:
@@ -46,8 +46,8 @@ class Worker:
tensor_parallel_size,
pipeline_parallel_size)
self.worker_id = rank
self.block_size = block_size
set_random_seed(seed)
self.seed = seed
set_random_seed(self.seed)
# Initialize the model.
self.model, self.dtype = get_model(
@@ -55,8 +55,10 @@ class Worker:
use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.max_num_batched_tokens = max_num_batched_tokens
initialize_all_reduce_launcher(
max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
self.max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
self.max_num_sequences = max_num_sequences
self.num_layers = self.model.config.num_hidden_layers
assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
@@ -66,12 +68,80 @@ class Worker:
# the random state is not affected by the model initialization.
set_random_seed(seed)
# Uninitialized cache engine. Will be initialized with
# self.init_cache_engine().
self.block_size = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
@torch.inference_mode()
def get_num_available_blocks(
self, block_size: int, cpu_swap_space: int,
gpu_memory_utilization: float) -> Tuple[int, int]:
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99,
top_k=self.model.config.vocab_size - 1)
seqs = []
for group_id in range(self.max_num_sequences):
seq_len = (self.max_num_batched_tokens // self.max_num_sequences +
(group_id < self.max_num_batched_tokens %
self.max_num_sequences))
seq_data = SequenceData([0] * seq_len)
seq = SequenceGroupMetadata(
group_id=group_id,
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
)
seqs.append(seq)
input_tokens, input_positions, input_metadata = self.prepare_inputs(seqs)
# Execute the model.
self.model(
input_ids=input_tokens,
positions=input_positions,
kv_caches=[(None, None)] * self.num_layers,
input_metadata=input_metadata,
cache_events=None,
)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
peak_memory = torch.cuda.max_memory_allocated()
total_gpu_memory = get_gpu_memory()
cache_block_size = get_cache_block_size(block_size, self.num_heads,
self.head_size, self.num_layers,
self.dtype)
num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
- peak_memory) // cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
torch.cuda.empty_cache()
# Reset the seed to ensure that the model output is not affected by
# the profiling.
set_random_seed(self.seed)
return num_gpu_blocks, num_cpu_blocks
def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
num_cpu_blocks: int):
self.block_size = block_size
self.cache_engine = CacheEngine(
worker_id=self.worker_id,
num_layers=self.num_layers,
num_heads=self.num_heads,
head_size=self.head_size,
block_size=block_size,
block_size=self.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
dtype=self.dtype,
@@ -129,6 +199,12 @@ class Worker:
# is always the first token in the sequence.
input_positions.extend(range(len(prompt_tokens)))
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.extend([0] * prompt_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
for i in range(prompt_len):
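
Putting the profiling pass into numbers, with hypothetical values: an 80 GiB GPU at the default 0.95 utilization, a measured 30 GiB peak from the dry run, and the ~12.5 MiB cache block size worked out above; the dummy-prompt split uses the defaults from add_server_arguments.

_GiB = 1 << 30
total_gpu_memory = 80 * _GiB          # as returned by get_gpu_memory()
gpu_memory_utilization = 0.95
peak_memory = 30 * _GiB               # torch.cuda.max_memory_allocated() after the dry run
cache_block_size = 13_107_200         # from the get_cache_block_size example
num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization - peak_memory)
                     // cache_block_size)
assert num_gpu_blocks == 3768         # roughly 46 GiB of KV cache

# The dry run spreads max_num_batched_tokens over max_num_sequences prompts,
# giving the first (tokens % seqs) prompts one extra token:
def profile_seq_lens(max_tokens: int, max_seqs: int):
    base, rem = divmod(max_tokens, max_seqs)
    return [base + (i < rem) for i in range(max_seqs)]

assert sum(profile_seq_lens(2560, 256)) == 2560   # 256 prompts of 10 tokens each
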