Use runtime profiling to replace manual memory analyzers (#81)
Parent: 825d8892b5
Commit: f756799b84
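Note: this commit drops the per-architecture memory analyzers and instead has each worker profile memory at runtime: it runs a dummy forward pass at the maximum batch size, reads the peak allocation from the CUDA allocator, and sizes the KV cache from whatever GPU memory remains under --gpu-memory-utilization. The sketch below is a minimal illustration of that block-count arithmetic (mirroring the new Worker.get_num_available_blocks and get_cache_block_size in the diff); the function name and the dtype_size/peak_memory parameters are illustrative stand-ins, not code from the commit.

def estimate_num_cache_blocks(block_size: int, num_heads: int, head_size: int,
                              num_layers: int, dtype_size: int,
                              total_gpu_memory: int, gpu_memory_utilization: float,
                              cpu_swap_space: int, peak_memory: int):
    # Bytes per KV cache block: one key block plus one value block, over all layers.
    key_cache_block = block_size * num_heads * head_size
    value_cache_block = key_cache_block
    cache_block_size = dtype_size * num_layers * (key_cache_block + value_cache_block)
    # GPU blocks fit in the memory left over after the profiled peak usage.
    num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization - peak_memory)
                         // cache_block_size)
    # CPU blocks are bounded only by the configured swap space.
    num_cpu_blocks = int(cpu_swap_space // cache_block_size)
    return num_gpu_blocks, num_cpu_blocks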
@@ -6,15 +6,14 @@ try:
     import ray
 except ImportError:
     ray = None
+import numpy as np
 import torch

 from cacheflow.core.scheduler import Scheduler
 from cacheflow.frontend.simple_frontend import SimpleFrontend
 from cacheflow.logger import init_logger
-from cacheflow.model_executor import get_memory_analyzer
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import SequenceGroup
-from cacheflow.utils import get_gpu_memory, get_cpu_memory
 from cacheflow.worker.controller import Controller, DeviceID

 logger = init_logger(__name__)
@@ -34,14 +33,13 @@ class Server:
         dtype: str,
         seed: int,
         swap_space: int,
+        gpu_memory_utilization: float,
         max_num_batched_tokens: int,
         max_num_sequences: int,
         num_nodes: int,
         num_devices_per_node: int,
         distributed_init_method: str,
         all_stage_devices: List[List[DeviceID]],
-        gpu_memory: int,
-        cpu_memory: int,
         use_ray: bool,
         log_stats: bool,
     ):
@@ -63,19 +61,6 @@ class Server:
             assert self.world_size == 1, (
                 "Only support single GPU without Ray.")

-        self.memory_analyzer = get_memory_analyzer(
-            model_name=model,
-            block_size=block_size,
-            dtype=dtype,
-            gpu_memory=gpu_memory,
-            cpu_memory=cpu_memory,
-            tensor_parallel_size=tensor_parallel_size,
-        )
-        self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
-            max_num_batched_tokens=max_num_batched_tokens)
-        self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
-            swap_space_gib=swap_space)
-
         # Create a controller for each pipeline stage.
         self.controllers: List[Controller] = []
         for i in range(pipeline_parallel_size):
@@ -87,19 +72,35 @@ class Server:
                 tensor_parallel_size=tensor_parallel_size,
                 distributed_init_method=distributed_init_method,
                 model_name=model,
-                block_size=block_size,
-                num_gpu_blocks=self.num_gpu_blocks,
-                num_cpu_blocks=self.num_cpu_blocks,
                 dtype=dtype,
                 seed=seed,
                 cache_dir=cache_dir,
                 use_dummy_weights=use_dummy_weights,
                 use_np_cache=use_np_cache,
                 max_num_batched_tokens=max_num_batched_tokens,
+                max_num_sequences=max_num_sequences,
                 use_ray=use_ray,
             )
             self.controllers.append(controller)

+        # Initialize cache engine.
+        all_worker_num_available_blocks = []
+        for controller in self.controllers:
+            all_worker_num_available_blocks.extend(
+                controller.get_num_available_blocks(
+                    block_size, swap_space, gpu_memory_utilization)
+            )
+        # Since we use a shared centralized controller, we take the minimum
+        # number of blocks across all workers to make sure all the memory
+        # operators can be applied to all workers.
+        self.num_gpu_blocks = np.min([b[0] for b in all_worker_num_available_blocks])
+        self.num_cpu_blocks = np.min([b[1] for b in all_worker_num_available_blocks])
+        logger.info(f'# GPU blocks: {self.num_gpu_blocks}, '
+                    f'# CPU blocks: {self.num_cpu_blocks}')
+        for controller in self.controllers:
+            controller.init_cache_engine(block_size, self.num_gpu_blocks,
+                                         self.num_cpu_blocks)
+
         # Create a scheduler.
         self.scheduler = Scheduler(
             controllers=self.controllers,
@@ -214,7 +215,11 @@ def initialize_cluster(
             all_stage_devices)


+_GiB = 1 << 30
+
+
 def add_server_arguments(parser: argparse.ArgumentParser):
+    """Shared arguments for CacheFlow servers."""
     # Model arguments
     parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
     parser.add_argument('--cache-dir', type=str, default=None,
@@ -238,6 +243,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
     parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
+    parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
     parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
     parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
     parser.add_argument('--log-stats', action='store_true', help='log system statistics')
@@ -245,8 +251,11 @@ def add_server_arguments(parser: argparse.ArgumentParser):


 def process_server_arguments(args: argparse.Namespace):
+    """Post process the parsed arguments."""
     if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
         args.use_ray = True
+    args.swap_space = args.swap_space * _GiB
+    args.max_num_sequences = min(args.max_num_sequences, args.max_num_batched_tokens)
     return args


@@ -274,14 +283,13 @@ def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
         dtype=args.dtype,
         seed=args.seed,
         swap_space=args.swap_space,
+        gpu_memory_utilization=args.gpu_memory_utilization,
         max_num_batched_tokens=args.max_num_batched_tokens,
         max_num_sequences=args.max_num_sequences,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,
         all_stage_devices=all_stage_devices,
-        gpu_memory=get_gpu_memory(),
-        cpu_memory=get_cpu_memory(),
         use_ray=args.use_ray,
         log_stats=args.log_stats,
     )
@@ -15,7 +15,7 @@ from cacheflow.core.server import (Server, add_server_arguments,
 from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
-from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
+from cacheflow.utils import Counter
 from cacheflow.worker.controller import DeviceID

 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
@@ -34,6 +34,7 @@ class FastAPIServer:
         dtype: str,
         seed: int,
         swap_space: int,
+        gpu_memory_utilization: float,
         max_num_batched_tokens: int,
         max_num_sequences: int,
         num_nodes: int,
@@ -41,6 +42,7 @@ class FastAPIServer:
         distributed_init_method: str,
         all_stage_devices: List[List[DeviceID]],
         server_use_ray: bool,
+        log_stats: bool,
     ):
         self.block_size = block_size

@@ -62,15 +64,15 @@ class FastAPIServer:
             dtype=dtype,
             seed=seed,
             swap_space=swap_space,
+            gpu_memory_utilization=gpu_memory_utilization,
             max_num_batched_tokens=max_num_batched_tokens,
             max_num_sequences=max_num_sequences,
             num_nodes=num_nodes,
             num_devices_per_node=num_devices_per_node,
             distributed_init_method=distributed_init_method,
             all_stage_devices=all_stage_devices,
-            gpu_memory=get_gpu_memory(),
-            cpu_memory=get_cpu_memory(),
             use_ray=server_use_ray,
+            log_stats=log_stats,
         )

         self.running_seq_groups: Dict[int, SequenceGroup] = {}
@@ -182,6 +184,7 @@ if __name__ == "__main__":
         dtype=args.dtype,
         seed=args.seed,
         swap_space=args.swap_space,
+        gpu_memory_utilization=args.gpu_memory_utilization,
         max_num_batched_tokens=args.max_num_batched_tokens,
         max_num_sequences=args.max_num_sequences,
         num_nodes=num_nodes,
@@ -189,6 +192,7 @@ if __name__ == "__main__":
         distributed_init_method=distributed_init_method,
         all_stage_devices=all_stage_devices,
         server_use_ray=args.use_ray,
+        log_stats=args.log_stats,
     )

     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
@@ -1,11 +1,12 @@
 from cacheflow.model_executor.input_metadata import InputMetadata
-from cacheflow.model_executor.model_loader import get_model, get_memory_analyzer
-from cacheflow.model_executor.utils import set_random_seed
+from cacheflow.model_executor.model_loader import get_model
+from cacheflow.model_executor.utils import (set_random_seed,
+                                            get_cache_block_size)


 __all__ = [
     "InputMetadata",
+    "get_cache_block_size",
     "get_model",
-    "get_memory_analyzer",
     "set_random_seed",
 ]
@@ -11,6 +11,8 @@ from cacheflow import pos_encoding_ops
 from cacheflow.model_executor.input_metadata import InputMetadata


+_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
+
 class GPTCacheFlowAttention(nn.Module):
     """GPT-style multi-head attention.

@@ -39,11 +41,19 @@ class GPTCacheFlowAttention(nn.Module):
     5. Output a flattened 1D tensor.
     """

-    def __init__(self, scale: float) -> None:
+    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
         super().__init__()
+        self.num_heads = num_heads
+        self.head_size = head_size
         self.scale = float(scale)
         self.attn_op = xops.fmha.cutlass.FwOp()

+        if self.head_size not in _SUPPORTED_HEAD_SIZES:
+            raise ValueError(f'head_size ({self.head_size}) is not supported by '
+                             'the single_query_cached_kv_attention kernel. '
+                             'Use one of the following head sizes: '
+                             f'{_SUPPORTED_HEAD_SIZES}.')
+
     def multi_query_kv_attention(
         self,
         output: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
@@ -74,14 +84,6 @@ class GPTCacheFlowAttention(nn.Module):
         value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
     ) -> None:
-        head_size = value_cache.shape[2]
-        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
-        if head_size not in supported_head_sizes:
-            raise ValueError(f'head_size ({head_size}) is not supported by '
-                             'the single_query_cached_kv_attention kernel. '
-                             'Use one of the following head sizes: '
-                             f'{supported_head_sizes}.')
-
         block_size = value_cache.shape[3]
         attention_ops.single_query_cached_kv_attention(
             output,
@@ -100,8 +102,8 @@ class GPTCacheFlowAttention(nn.Module):
         query: torch.Tensor,  # [num_tokens, num_heads * head_size]
         key: torch.Tensor,  # [num_tokens, num_heads * head_size]
         value: torch.Tensor,  # [num_tokens, num_heads * head_size]
-        key_cache: torch.Tensor,  # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+        key_cache: Optional[torch.Tensor],  # [num_blocks, num_heads, head_size/x, block_size, x]
+        value_cache: Optional[torch.Tensor],  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
@@ -109,11 +111,9 @@ class GPTCacheFlowAttention(nn.Module):
         # tensor of shape [num_tokens, 3 * num_heads * head_size].

         # Reshape the query, key, and value tensors.
-        num_heads = value_cache.shape[1]
-        head_size = value_cache.shape[2]
-        query = query.view(-1, num_heads, head_size)
-        key = key.view(-1, num_heads, head_size)
-        value = value.view(-1, num_heads, head_size)
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_heads, self.head_size)
+        value = value.view(-1, self.num_heads, self.head_size)

         # Pre-allocate the output tensor.
         output = torch.empty_like(query)
@@ -134,8 +134,11 @@ class GPTCacheFlowAttention(nn.Module):
             cache_event.wait()

         # Reshape the keys and values and store them in the cache.
+        # When key_cache and value_cache are not provided, the new key
+        # and value vectors will not be cached.
         num_valid_tokens = input_metadata.num_valid_tokens
-        if num_valid_tokens > 0:
+        if (num_valid_tokens > 0 and key_cache is not None
+            and value_cache is not None):
             # The stride is 3 because the key and value are sliced from qkv.
             cache_ops.reshape_and_cache(
                 key[:num_valid_tokens],
@@ -146,6 +149,10 @@ class GPTCacheFlowAttention(nn.Module):
             )

         if input_metadata.num_generation_tokens > 0:
+            assert key_cache is not None and value_cache is not None, (
+                "key_cache and value_cache must be provided when "
+                "generating tokens."
+            )
             # Compute the attention op for generation tokens.
             self.single_query_cached_kv_attention(
                 output[num_prompt_tokens:num_valid_tokens],
@@ -156,7 +163,7 @@ class GPTCacheFlowAttention(nn.Module):

         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
-        return output.view(-1, num_heads * head_size)
+        return output.view(-1, self.num_heads * self.head_size)


 class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
@@ -164,12 +171,14 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):

     def __init__(
         self,
+        num_heads: int,
+        head_size: int,
         scale: float,
         rotary_dim: int,
         max_position: int = 8192,
         base: int = 10000,
     ) -> None:
-        super().__init__(scale)
+        super().__init__(num_heads, head_size, scale)

         # Create the cos and sin cache.
         inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
@@ -199,12 +208,11 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
     ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        head_size = value_cache.shape[2]
         pos_encoding_ops.rotary_embedding_neox(
             positions,
             query,
             key,
-            head_size,
+            self.head_size,
             self.cos_sin_cache,
         )
         return super().forward(
@@ -74,7 +74,7 @@ class Sampler(nn.Module):
         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == probs.shape[0]
-        if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
+        if any(p < 1.0 for p in top_ps) or any(k != self.vocab_size for k in top_ks):
             probs = _apply_top_p_top_k(probs, top_ps, top_ks)

         # Sample the next tokens.
@@ -1,370 +0,0 @@
-import torch
-from transformers import AutoConfig
-
-from cacheflow.logger import init_logger
-from cacheflow.model_executor.utils import get_dtype_size
-
-logger = init_logger(__name__)
-
-_GiB = 1 << 30
-
-
-class CacheFlowMemoryAnalyzer:
-
-    def get_max_num_gpu_blocks(
-        self,
-        max_num_batched_tokens: int,
-        memory_utilization: float,
-    ) -> int:
-        raise NotImplementedError()
-
-    def get_workspace_size(self) -> int:
-        return 1 * _GiB
-
-    def get_cache_block_size(self) -> int:
-        raise NotImplementedError()
-
-    def get_max_num_cpu_blocks(
-        self,
-        swap_space_gib: int,
-    ) -> int:
-        swap_space = swap_space_gib * _GiB
-        cpu_memory = self.cpu_memory
-        if swap_space > 0.8 * cpu_memory:
-            raise ValueError(f'The swap space ({swap_space_gib:.2f} GiB) '
-                             'takes more than 80% of the available memory '
-                             f'({cpu_memory / _GiB:.2f} GiB).'
-                             'Please check the swap space size.')
-        if swap_space > 0.5 * cpu_memory:
-            logger.info(f'WARNING: The swap space ({swap_space_gib:.2f} GiB) '
-                        'takes more than 50% of the available memory '
-                        f'({cpu_memory / _GiB:.2f} GiB).'
-                        'This may slow the system performance.')
-        max_num_blocks = swap_space // self.get_cache_block_size()
-        return max_num_blocks
-
-    def get_param_size(self) -> int:
-        raise NotImplementedError()
-
-    def get_max_act_size(self, max_num_batched_tokens: int) -> int:
-        raise NotImplementedError()
-
-    def get_cache_block_size(self) -> int:
-        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
-        value_cache_block = key_cache_block
-        total = self.num_layers * (key_cache_block + value_cache_block)
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * total
-
-    def get_max_num_gpu_blocks(
-        self,
-        max_num_batched_tokens: int,
-        memory_utilization: float = 0.95,
-    ) -> int:
-        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
-        usable_memory = int(memory_utilization * self.gpu_memory)
-
-        param_size = self.get_param_size()
-        act_size = self.get_max_act_size(max_num_batched_tokens)
-        workspace_size = self.get_workspace_size()
-
-        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
-        if max_cache_size <= 0:
-            raise RuntimeError('Not enough GPU memory.')
-        max_num_blocks = max_cache_size // self.get_cache_block_size()
-        return max_num_blocks
-
-
-class GPT2MemoryAnalyzer(CacheFlowMemoryAnalyzer):
-
-    def __init__(
-        self,
-        model_name: str,
-        block_size: int,
-        dtype: torch.dtype,
-        gpu_memory: int,
-        cpu_memory: int,
-        tensor_parallel_size: int,
-    ) -> None:
-        self.model_name = model_name
-        self.block_size = block_size
-        self.dtype = dtype
-        self.gpu_memory = gpu_memory
-        self.cpu_memory = cpu_memory
-        self.tensor_parallel_size = tensor_parallel_size
-
-        config = AutoConfig.from_pretrained(model_name)
-        self.num_layers = config.num_hidden_layers
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_size = config.hidden_size // self.num_heads
-        self.ffn_size = config.n_inner if config.n_inner is not None else 4 * self.hidden_size
-        self.vocab_size = config.vocab_size
-        self.max_position = config.max_position_embeddings
-
-    def get_param_size(self) -> int:
-        word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
-        position_embedding = self.max_position * self.hidden_size
-
-        ln1 = 2 * self.hidden_size
-        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        mha = ln1 + q + k + v + out
-
-        ln2 = 2 * self.hidden_size
-        ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
-        ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        ffn = ln2 + ffn1 + ffn2
-
-        total = (word_embedding + position_embedding +
-                 self.num_layers * (mha + ffn))
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * total
-
-    def get_max_act_size(
-        self,
-        max_num_batched_tokens: int,
-    ) -> int:
-        # NOTE: We approxmiately calculate the maximum activation size by
-        # estimating
-        # 1) the maximum activation tensor size during inference
-        # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
-        residual = max_num_batched_tokens * self.hidden_size
-        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
-        ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
-        # Double the activation size for input and output.
-        max_act = 2 * (max(qkv, ffn) + residual)
-        # Size of output logits.
-        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
-        max_act = max(max_act, output_logits)
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * max_act
-
-
-class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
-
-    def __init__(
-        self,
-        model_name: str,
-        block_size: int,
-        dtype: torch.dtype,
-        gpu_memory: int,
-        cpu_memory: int,
-        tensor_parallel_size: int,
-    ) -> None:
-        self.model_name = model_name
-        self.block_size = block_size
-        self.dtype = dtype
-        self.gpu_memory = gpu_memory
-        self.cpu_memory = cpu_memory
-        self.tensor_parallel_size = tensor_parallel_size
-
-        config = AutoConfig.from_pretrained(model_name)
-        self.num_layers = config.num_hidden_layers
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_size = config.hidden_size // self.num_heads
-        self.ffn_size = config.ffn_dim
-        self.embedding_size = config.word_embed_proj_dim
-        self.vocab_size = config.vocab_size
-        self.max_position = config.max_position_embeddings
-
-    def get_param_size(self) -> int:
-        word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
-        if self.embedding_size != self.hidden_size:
-            # Project in/out.
-            word_embedding += 2 * self.embedding_size * self.hidden_size
-        position_embedding = self.max_position * self.hidden_size
-
-        ln1 = 2 * self.hidden_size
-        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        mha = ln1 + q + k + v + out
-
-        ln2 = 2 * self.hidden_size
-        ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
-        ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        ffn = ln2 + ffn1 + ffn2
-
-        total = (word_embedding + position_embedding +
-                 self.num_layers * (mha + ffn))
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * total
-
-    def get_max_act_size(
-        self,
-        max_num_batched_tokens: int,
-    ) -> int:
-        # NOTE: We approxmiately calculate the maximum activation size by
-        # estimating
-        # 1) the maximum activation tensor size during inference
-        # 2) the residual tensor size during inference
-        # Here, we assume that we use memory-efficient attention which
-        # does not materialize the attention maps in GPU DRAM.
-        residual = max_num_batched_tokens * self.hidden_size
-        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
-        ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
-        # Double the activation size for input and output.
-        max_act = 2 * (max(qkv, ffn) + residual)
-        # Size of output logits.
-        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
-        max_act = max(max_act, output_logits)
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * max_act
-
-
-class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
-
-    def __init__(
-        self,
-        model_name: str,
-        block_size: int,
-        dtype: torch.dtype,
-        gpu_memory: int,
-        cpu_memory: int,
-        tensor_parallel_size: int,
-    ) -> None:
-        self.model_name = model_name
-        self.block_size = block_size
-        self.dtype = dtype
-        self.gpu_memory = gpu_memory
-        self.cpu_memory = cpu_memory
-        self.tensor_parallel_size = tensor_parallel_size
-
-        config = AutoConfig.from_pretrained(model_name)
-        self.num_layers = config.num_hidden_layers
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_size = config.hidden_size // self.num_heads
-        self.ffn_size = config.intermediate_size
-        self.vocab_size = config.vocab_size
-        self.max_position = 8192
-
-    def get_param_size(self) -> int:
-        # NOTE: LLaMA does not tie the two embeddings.
-        word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
-        lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
-
-        # NOTE: LLaMA does not have bias terms.
-        ln1 = self.hidden_size
-        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size
-        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size
-        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size
-        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size
-        # Rotary embedding.
-        # TODO(woosuk): Share the rotary embedding between layers.
-        rot = self.max_position * self.head_size
-        mha = ln1 + q + k + v + out + rot
-
-        ln2 = self.hidden_size
-        gate = self.hidden_size * self.ffn_size // self.tensor_parallel_size
-        down = self.ffn_size * self.hidden_size // self.tensor_parallel_size
-        up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
-        ffn = ln2 + gate + down + up
-
-        total = word_embedding + self.num_layers * (mha + ffn) + lm_head
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * total
-
-    def get_max_act_size(
-        self,
-        max_num_batched_tokens: int,
-    ) -> int:
-        # NOTE: We approxmiately calculate the maximum activation size by
-        # estimating
-        # 1) the maximum activation tensor size during inference
-        # 2) the residual tensor size during inference
-        # Here, we assume that we use memory-efficient attention which
-        # does not materialize the attention maps in GPU DRAM.
-        residual = max_num_batched_tokens * self.hidden_size
-        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
-        ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
-        # Double the activation size for input and output.
-        max_act = 2 * (max(qkv, ffn) + residual)
-        # Size of output logits.
-        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
-        max_act = max(max_act, output_logits)
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * max_act
-
-
-class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
-
-    def __init__(
-        self,
-        model_name: str,
-        block_size: int,
-        dtype: torch.dtype,
-        gpu_memory: int,
-        cpu_memory: int,
-        tensor_parallel_size: int,
-    ) -> None:
-        self.model_name = model_name
-        self.block_size = block_size
-        self.dtype = dtype
-        self.gpu_memory = gpu_memory
-        self.cpu_memory = cpu_memory
-        self.tensor_parallel_size = tensor_parallel_size
-
-        config = AutoConfig.from_pretrained(model_name)
-        self.num_layers = config.num_hidden_layers
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_size = config.hidden_size // self.num_heads
-        self.ffn_size = config.intermediate_size
-        self.vocab_size = config.vocab_size
-        self.max_position = 8192
-        self.tie_word_embeddings = config.tie_word_embeddings
-
-    def get_param_size(self) -> int:
-        word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
-        if self.tie_word_embeddings:
-            lm_head = 0
-        else:
-            lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
-
-        ln1 = 2 * self.hidden_size
-        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        # Rotary embedding.
-        # TODO(woosuk): Share the rotary embedding between layers.
-        rot = self.max_position * self.head_size
-        mha = ln1 + q + k + v + out + rot
-
-        ln2 = 2 * self.hidden_size
-        ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
-        ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
-        ffn = ln2 + ffn1 + ffn2
-
-        total = word_embedding + self.num_layers * (mha + ffn) + lm_head
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * total
-
-    def get_max_act_size(
-        self,
-        max_num_batched_tokens: int,
-    ) -> int:
-        # NOTE: We approxmiately calculate the maximum activation size by
-        # estimating
-        # 1) the maximum activation tensor size during inference
-        # 2) the residual tensor size during inference
-        # Here, we assume that we use memory-efficient attention which
-        # does not materialize the attention maps in GPU DRAM.
-        residual = max_num_batched_tokens * self.hidden_size
-        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
-        ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
-        # Double the activation size for input and output.
-        max_act = 2 * (max(qkv, ffn) + residual)
-        # Size of output logits.
-        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
-        max_act = max(max_act, output_logits)
-        dtype_size = get_dtype_size(self.dtype)
-        return dtype_size * max_act
@@ -5,9 +5,6 @@ import torch
 import torch.nn as nn
 from transformers import AutoConfig, PretrainedConfig

-from cacheflow.model_executor.memory_analyzer import (
-    CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,
-    LlamaMemoryAnalyzer, OPTMemoryAnalyzer)
 from cacheflow.model_executor.models import (
     GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
 from cacheflow.model_executor.utils import get_torch_dtype
@@ -22,14 +19,6 @@ _MODEL_REGISTRY = {
     "OPTForCausalLM": OPTForCausalLM,
 }

-_MEMORY_ANALYZERS = {
-    "GPT2LMHeadModel": GPT2MemoryAnalyzer,
-    "GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
-    "LlamaForCausalLM": LlamaMemoryAnalyzer,
-    "OPTForCausalLM": OPTMemoryAnalyzer,
-}
-
-
 def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
@@ -41,17 +30,6 @@ def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
     )


-def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
-    architectures = getattr(config, "architectures", [])
-    for arch in architectures:
-        if arch in _MEMORY_ANALYZERS:
-            return _MEMORY_ANALYZERS[arch]
-    raise ValueError(
-        f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
-    )
-
-
 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
@@ -100,18 +78,3 @@ def get_model(
         model = model.cuda()
     return model.eval(), torch_dtype

-
-def get_memory_analyzer(
-    model_name: str,
-    block_size: int,
-    dtype: str,
-    gpu_memory: int,
-    cpu_memory: int,
-    tensor_parallel_size: int = 1,
-) -> CacheFlowMemoryAnalyzer:
-    config = AutoConfig.from_pretrained(model_name)
-    torch_dtype = _get_dtype(config, dtype)
-    memory_analyzer = _get_memory_analyzer(config)
-    return memory_analyzer(
-        model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
-        tensor_parallel_size)
@@ -58,7 +58,8 @@ class GPT2Attention(nn.Module):
         self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True,
                                         input_is_parallel=True,
                                         perform_initialization=False)
-        self.attn = GPTCacheFlowAttention(scale=self.scale)
+        self.attn = GPTCacheFlowAttention(self.num_heads, self.head_dim,
+                                          scale=self.scale)

     def forward(
         self,
@@ -62,7 +62,8 @@ class GPTNeoXAttention(nn.Module):
         scaling = self.head_size ** -0.5
         rotary_dim = int(self.head_size * config.rotary_pct)
         assert rotary_dim % 2 == 0
-        self.attn = GPTNeoXCacheFlowAttention(scaling, rotary_dim)
+        self.attn = GPTNeoXCacheFlowAttention(self.num_heads, self.head_size,
+                                              scaling, rotary_dim)

     def forward(
         self,
@@ -104,7 +104,8 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)
+        self.attn = GPTNeoXCacheFlowAttention(self.num_heads, self.head_dim,
+                                              self.scaling, rotary_dim=self.head_dim)

     def forward(
         self,
@@ -74,7 +74,8 @@ class OPTAttention(nn.Module):
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-        self.attn = GPTCacheFlowAttention(scale=self.scaling)
+        self.attn = GPTCacheFlowAttention(self.num_heads, self.head_dim,
+                                          scale=self.scaling)

     def forward(
         self,
@@ -40,3 +40,15 @@ def set_random_seed(seed: int) -> None:

     if model_parallel_is_initialized():
         model_parallel_cuda_manual_seed(seed)
+
+
+def get_cache_block_size(block_size: int,
+                         num_heads: int,
+                         head_size: int,
+                         num_layers: int,
+                         dtype: str) -> int:
+    key_cache_block = block_size * num_heads * head_size
+    value_cache_block = key_cache_block
+    total = num_layers * (key_cache_block + value_cache_block)
+    dtype_size = get_dtype_size(dtype)
+    return dtype_size * total
@@ -23,23 +23,18 @@ class Controller:
         pipeline_parallel_size: int,
         distributed_init_method: str,
         model_name: str,
-        block_size: int,
-        num_gpu_blocks: int,
-        num_cpu_blocks: int,
         dtype: str,
         seed: int,
         cache_dir: Optional[str],
         use_dummy_weights: bool,
         use_np_cache: bool,
         max_num_batched_tokens: int,
+        max_num_sequences: int,
         use_ray: bool,
     ) -> None:
         self.stage_id = stage_id
         self.stage_devices = stage_devices
         self.model_name = model_name
-        self.block_size = block_size
-        self.num_gpu_blocks = num_gpu_blocks
-        self.num_cpu_blocks = num_cpu_blocks
         self.use_ray = use_ray

         # Which pipeline stage is this node assigned to?
@@ -56,9 +51,6 @@ class Controller:
             worker_cls = Worker
             worker = worker_cls(
                 model_name=model_name,
-                block_size=block_size,
-                num_gpu_blocks=num_gpu_blocks,
-                num_cpu_blocks=num_cpu_blocks,
                 dtype=dtype,
                 seed=seed,
                 distributed_init_method=distributed_init_method,
@@ -70,9 +62,44 @@ class Controller:
                 use_dummy_weights=use_dummy_weights,
                 use_np_cache=use_np_cache,
                 max_num_batched_tokens=max_num_batched_tokens,
+                max_num_sequences=max_num_sequences,
             )
             self.workers.append(worker)

+    def get_num_available_blocks(self, block_size: int, cpu_swap_space: int,
+                                 gpu_memory_utilization: float) -> List[Tuple[int, int]]:
+        all_worker_results = []
+        for worker in self.workers:
+            executor = worker.get_num_available_blocks
+            if self.use_ray:
+                executor = executor.remote
+
+            result = executor(
+                block_size,
+                cpu_swap_space,
+                gpu_memory_utilization,
+            )
+            all_worker_results.append(result)
+        if self.use_ray:
+            all_worker_results = ray.get(all_worker_results)
+        return all_worker_results
+
+    def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
+                          num_cpu_blocks: int):
+        all_worker_futures = []
+        for worker in self.workers:
+            executor = worker.init_cache_engine
+            if self.use_ray:
+                executor = executor.remote
+            future = executor(
+                block_size,
+                num_gpu_blocks,
+                num_cpu_blocks,
+            )
+            all_worker_futures.append(future)
+        if self.use_ray:
+            ray.get(all_worker_futures)
+
     def set_next(
         self,
         next_node: Union['Controller', 'Scheduler'],
@@ -3,7 +3,8 @@ from typing import Dict, List, Optional, Tuple

 import torch

-from cacheflow.model_executor import get_model, InputMetadata, set_random_seed
+from cacheflow.model_executor import (get_model, get_cache_block_size,
+                                      InputMetadata, set_random_seed)
 from cacheflow.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel,
     initialize_all_reduce_launcher,
@@ -12,6 +13,7 @@ from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
                                 SequenceOutputs)
 from cacheflow.worker.cache_engine import CacheEngine
+from cacheflow.utils import get_gpu_memory


 class Worker:
@@ -25,9 +27,6 @@ class Worker:
     def __init__(
         self,
         model_name: str,
-        block_size: int,
-        num_gpu_blocks: int,
-        num_cpu_blocks: int,
         dtype: str,
         seed: int,
         distributed_init_method: str,
@@ -37,6 +36,7 @@ class Worker:
         use_dummy_weights: bool,
         use_np_cache: bool,
         max_num_batched_tokens: int,
+        max_num_sequences: int,
         tensor_parallel_size: int = 1,
         pipeline_parallel_size: int = 1,
     ) -> None:
@@ -46,8 +46,8 @@ class Worker:
                                   tensor_parallel_size,
                                   pipeline_parallel_size)
         self.worker_id = rank
-        self.block_size = block_size
-        set_random_seed(seed)
+        self.seed = seed
+        set_random_seed(self.seed)

         # Initialize the model.
         self.model, self.dtype = get_model(
@@ -55,8 +55,10 @@ class Worker:
             use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
+        self.max_num_batched_tokens = max_num_batched_tokens
         initialize_all_reduce_launcher(
-            max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
+            self.max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
+        self.max_num_sequences = max_num_sequences
         self.num_layers = self.model.config.num_hidden_layers
         assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
         self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
@@ -66,12 +68,80 @@ class Worker:
         # the random state is not affected by the model initialization.
         set_random_seed(seed)

+        # Uninitialized cache engine. Will be initialized with
+        # self.init_cache_engine().
+        self.block_size = None
+        self.cache_engine = None
+        self.cache_events = None
+        self.gpu_cache = None
+
+    @torch.inference_mode()
+    def get_num_available_blocks(
+        self, block_size: int, cpu_swap_space: int,
+        gpu_memory_utilization: float) -> Tuple[int, int]:
+        # Profile the memory usage of the model and get the maximum number of
+        # cache blocks that can be allocated with the remaining free memory.
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+
+        # Profile memory usage with max_num_sequences sequences and the total
+        # number of tokens equal to max_num_batched_tokens.
+
+        # Enable top-k sampling to reflect the accurate memory usage.
+        sampling_params = SamplingParams(top_p=0.99,
+                                         top_k=self.model.config.vocab_size - 1)
+        seqs = []
+        for group_id in range(self.max_num_sequences):
+            seq_len = (self.max_num_batched_tokens // self.max_num_sequences +
+                       (group_id < self.max_num_batched_tokens %
+                        self.max_num_sequences))
+            seq_data = SequenceData([0] * seq_len)
+            seq = SequenceGroupMetadata(
+                group_id=group_id,
+                is_prompt=True,
+                seq_data={group_id: seq_data},
+                sampling_params=sampling_params,
+                block_tables=None,
+            )
+            seqs.append(seq)
+
+        input_tokens, input_positions, input_metadata = self.prepare_inputs(seqs)
+
+        # Execute the model.
+        self.model(
+            input_ids=input_tokens,
+            positions=input_positions,
+            kv_caches=[(None, None)] * self.num_layers,
+            input_metadata=input_metadata,
+            cache_events=None,
+        )
+
+        # Calculate the number of blocks that can be allocated with the
+        # profiled peak memory.
+        torch.cuda.synchronize()
+        peak_memory = torch.cuda.max_memory_allocated()
+        total_gpu_memory = get_gpu_memory()
+        cache_block_size = get_cache_block_size(block_size, self.num_heads,
+                                                self.head_size, self.num_layers,
+                                                self.dtype)
+        num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
+                              - peak_memory) // cache_block_size)
+        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
+        torch.cuda.empty_cache()
+        # Reset the seed to ensure that the model output is not affected by
+        # the profiling.
+        set_random_seed(self.seed)
+        return num_gpu_blocks, num_cpu_blocks
+
+    def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
+                          num_cpu_blocks: int):
+        self.block_size = block_size
         self.cache_engine = CacheEngine(
             worker_id=self.worker_id,
             num_layers=self.num_layers,
             num_heads=self.num_heads,
             head_size=self.head_size,
-            block_size=block_size,
+            block_size=self.block_size,
             num_gpu_blocks=num_gpu_blocks,
             num_cpu_blocks=num_cpu_blocks,
             dtype=self.dtype,
@@ -129,6 +199,12 @@ class Worker:
             # is always the first token in the sequence.
             input_positions.extend(range(len(prompt_tokens)))

+            if seq_group_metadata.block_tables is None:
+                # During memory profiling, the block tables are not initialized
+                # yet. In this case, we just use a dummy slot mapping.
+                slot_mapping.extend([0] * prompt_len)
+                continue
+
             # Compute the slot mapping.
             block_table = seq_group_metadata.block_tables[seq_id]
             for i in range(prompt_len):