Refactor system architecture (#82)
parent 8917782af6
commit 7c041ab578
@@ -4,8 +4,8 @@ import pickle
 import time
 from typing import Any, Dict, List, Optional, Tuple

-from cacheflow.master.block_manager import BlockSpaceManager
-from cacheflow.master.policy import PolicyFactory
+from cacheflow.core.block_manager import BlockSpaceManager
+from cacheflow.core.policy import PolicyFactory
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
@@ -8,20 +8,21 @@ try:
 except ImportError:
     ray = None

-from cacheflow.master.scheduler import Scheduler
-from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.models import get_memory_analyzer
-from cacheflow.worker.controller import Controller, DeviceID
+from cacheflow.core.scheduler import Scheduler
+from cacheflow.frontend.simple_frontend import SimpleFrontend
+from cacheflow.logger import init_logger
+from cacheflow.model_executor import get_memory_analyzer
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
+from cacheflow.worker.controller import Controller, DeviceID


 logger = init_logger(__name__)


 class Server:

     def __init__(
         self,
         model: str,
@@ -1,22 +1,22 @@
 import argparse
 import asyncio
+import json
 import time
 from typing import List, Dict, Optional
-import json

-import ray
-from transformers import AutoTokenizer
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
+import ray
+from transformers import AutoTokenizer
 import uvicorn

+from cacheflow.core.server import (Server, add_server_arguments,
+                                   process_server_arguments,
+                                   initialize_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     process_server_arguments,
-                                     initialize_cluster)
-from cacheflow.worker.controller import DeviceID
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
+from cacheflow.worker.controller import DeviceID

 TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
 app = FastAPI()
11 cacheflow/model_executor/__init__.py Normal file
@@ -0,0 +1,11 @@
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.model_loader import get_model, get_memory_analyzer
+from cacheflow.model_executor.utils import set_random_seed
+
+
+__all__ = [
+    "InputMetadata",
+    "get_model",
+    "get_memory_analyzer",
+    "set_random_seed",
+]
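The file above defines the public surface of the new cacheflow.model_executor package. As an illustrative sketch (not part of the diff), callers that previously imported these names from cacheflow.models switch to the new path:

    # Old (removed in this commit):
    #   from cacheflow.models import InputMetadata, get_memory_analyzer, get_model
    # New:
    from cacheflow.model_executor import (
        InputMetadata, get_memory_analyzer, get_model, set_random_seed)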
@@ -7,7 +7,7 @@ from xformers import ops as xops
 from cacheflow import attention_ops
 from cacheflow import cache_ops
 from cacheflow import pos_encoding_ops
-from cacheflow.models import InputMetadata
+from cacheflow.model_executor.input_metadata import InputMetadata


 class GPTCacheFlowAttention(nn.Module):
@@ -3,10 +3,11 @@ from typing import Dict, List, Tuple
 import torch
 import torch.nn as nn

-from cacheflow.models import InputMetadata
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+    gather_from_tensor_model_parallel_region)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import SequenceOutputs
-from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_parallel_region


 class Sampler(nn.Module):
@@ -27,7 +28,7 @@ class Sampler(nn.Module):
         # Get the logits for the next tokens.
         logits = torch.matmul(hidden_states, embedding.t())
         logits = gather_from_tensor_model_parallel_region(logits)
-        # Remove paddings in vocab.
+        # Remove paddings in vocab (if any).
         logits = logits[:, :self.vocab_size]

         # Apply temperature scaling.
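Context for the slice after the gather: with tensor parallelism the vocabulary dimension of the logits is sharded across ranks and may be padded so it divides evenly, so the gathered tensor can have more columns than the real vocabulary; the [:, :vocab_size] slice drops them. A toy sketch with made-up sizes (illustrative assumption, not taken from this diff):

    import torch
    num_tokens, vocab_size, padded_vocab_size = 8, 50257, 50304  # hypothetical sizes
    gathered = torch.randn(num_tokens, padded_vocab_size)  # stands in for the all-gathered logits
    logits = gathered[:, :vocab_size]  # padding columns removed, as in the hunk above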
@@ -2,7 +2,7 @@ import torch
 from transformers import AutoConfig

 from cacheflow.logger import init_logger
-from cacheflow.models.utils import get_dtype_size
+from cacheflow.model_executor.utils import get_dtype_size


 logger = init_logger(__name__)
@@ -5,16 +5,13 @@ import torch.nn as nn
 from transformers import AutoConfig
 from transformers import PretrainedConfig

-from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
-from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
-from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
-from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
-from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
-from cacheflow.models.gpt2 import GPT2LMHeadModel
-from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
-from cacheflow.models.llama import LlamaForCausalLM
-from cacheflow.models.opt import OPTForCausalLM
-from cacheflow.models.utils import get_torch_dtype
+from cacheflow.model_executor.memory_analyzer import (
+    CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,
+    LlamaMemoryAnalyzer, OPTMemoryAnalyzer)
+from cacheflow.model_executor.models import (
+    GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
+from cacheflow.model_executor.utils import get_torch_dtype
+from cacheflow.model_executor.weight_utils import initialize_dummy_weights


 _MODELS = {
@@ -77,7 +74,7 @@ def get_model(
         model = model.cuda()
         # NOTE(woosuk): For precise performance evaluation, we assign
         # random values to the weights.
-        model.initialize_dummy_weights()
+        initialize_dummy_weights(model)
     else:
         # Create a model instance.
         model = model_class(config)
12 cacheflow/model_executor/models/__init__.py Normal file
@@ -0,0 +1,12 @@
+from cacheflow.model_executor.models.gpt_neox import GPTNeoXForCausalLM
+from cacheflow.model_executor.models.gpt2 import GPT2LMHeadModel
+from cacheflow.model_executor.models.llama import LlamaForCausalLM
+from cacheflow.model_executor.models.opt import OPTForCausalLM
+
+
+__all__ = [
+    "GPT2LMHeadModel",
+    "GPTNeoXForCausalLM",
+    "LlamaForCausalLM",
+    "OPTForCausalLM",
+]
@@ -5,16 +5,15 @@ import torch
 from torch import nn
 from transformers import GPT2Config

-from cacheflow.models import InputMetadata
-from cacheflow.models.attention import GPTCacheFlowAttention
-from cacheflow.models.sample import Sampler
-from cacheflow.models.utils import (hf_model_weights_iterator,
-                                    load_tensor_parallel_weights)
-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
+from cacheflow.model_executor.layers.sampler import Sampler
+from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
+                                                   load_tensor_parallel_weights)
+from cacheflow.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
-                                                      ColumnParallelLinear,
-                                                      RowParallelLinear)
+from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from cacheflow.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -258,8 +257,5 @@ class GPT2LMHeadModel(nn.Module):
                 raise ValueError(f"Unexpected parameter name {name}")
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
-                                         self._row_parallel_weights)
-
-    def initialize_dummy_weights(self) -> None:
-        for param in self.state_dict().values():
-            param.data.uniform_(-1e-3, 1e-3)
+                                         self._row_parallel_weights,
+                                         tensor_model_parallel_rank)
@@ -3,17 +3,17 @@ from typing import Dict, List, Optional, Tuple

 import torch
 from torch import nn
 from transformers import GPTNeoXConfig

-from cacheflow.models import InputMetadata
-from cacheflow.models.attention import GPTNeoXCacheFlowAttention
-from cacheflow.models.sample import Sampler
-from cacheflow.models.utils import (hf_model_weights_iterator,
-                                    load_tensor_parallel_weights)
-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
+from cacheflow.model_executor.layers.sampler import Sampler
+from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
+                                                   load_tensor_parallel_weights)
+from cacheflow.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
-                                                      ColumnParallelLinear,
-                                                      RowParallelLinear)
+from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from cacheflow.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -21,7 +21,7 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]

 class GPTNeoXAttention(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.total_num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
@@ -63,7 +63,7 @@ class GPTNeoXAttention(nn.Module):


 class GPTNeoXMLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
                                                   config.intermediate_size,
@@ -86,7 +86,7 @@ class GPTNeoXMLP(nn.Module):

 class GPTNeoXLayer(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -129,7 +129,7 @@ class GPTNeoXLayer(nn.Module):


 class GPTNeoXModel(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.config = config

@@ -227,8 +227,5 @@ class GPTNeoXForCausalLM(nn.Module):
                 raise ValueError(f"Unexpected weight name: {name}")
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
-                                         self._row_parallel_weights)
-
-    def initialize_dummy_weights(self) -> None:
-        for param in self.state_dict().values():
-            param.data.uniform_(-1e-3, 1e-3)
+                                         self._row_parallel_weights,
+                                         tensor_model_parallel_rank)
@@ -5,18 +5,18 @@ import torch
 from torch import nn
 from transformers import LlamaConfig

-from cacheflow.models import InputMetadata
-from cacheflow.models.activation import SiluAndMul
-from cacheflow.models.attention import GPTNeoXCacheFlowAttention
-from cacheflow.models.layernorm import RMSNorm
-from cacheflow.models.sample import Sampler
-from cacheflow.models.utils import (hf_model_weights_iterator,
-                                    load_tensor_parallel_weights)
-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.sequence import SequenceOutputs
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import SiluAndMul
+from cacheflow.model_executor.layers.layernorm import RMSNorm
+from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
+from cacheflow.model_executor.layers.sampler import Sampler
+from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
+                                                   load_tensor_parallel_weights)
+from cacheflow.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
-                                                      ColumnParallelLinear,
-                                                      RowParallelLinear)
+from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from cacheflow.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -263,8 +263,5 @@ class LlamaForCausalLM(nn.Module):
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
-                                         self._row_parallel_weights)
-
-    def initialize_dummy_weights(self) -> None:
-        for param in self.state_dict().values():
-            param.data.uniform_(-1e-3, 1e-3)
+                                         self._row_parallel_weights,
+                                         tensor_model_parallel_rank)
@@ -5,16 +5,15 @@ import torch
 from torch import nn
 from transformers import OPTConfig

-from cacheflow.models import InputMetadata
-from cacheflow.models.attention import GPTCacheFlowAttention
-from cacheflow.models.sample import Sampler
-from cacheflow.models.utils import (hf_model_weights_iterator,
-                                    load_tensor_parallel_weights)
-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
+from cacheflow.model_executor.layers.sampler import Sampler
+from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
+                                                   load_tensor_parallel_weights)
+from cacheflow.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
-                                                      ColumnParallelLinear,
-                                                      RowParallelLinear)
+from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from cacheflow.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -288,8 +287,5 @@ class OPTForCausalLM(nn.Module):
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
-                                         self._row_parallel_weights)
-
-    def initialize_dummy_weights(self) -> None:
-        for param in self.state_dict().values():
-            param.data.uniform_(-1e-3, 1e-3)
+                                         self._row_parallel_weights,
+                                         tensor_model_parallel_rank)
12 cacheflow/model_executor/parallel_utils/__init__.py Normal file
@@ -0,0 +1,12 @@
+import cacheflow.model_executor.parallel_utils.parallel_state
+import cacheflow.model_executor.parallel_utils.tensor_parallel
+import cacheflow.model_executor.parallel_utils.utils
+
+# Alias parallel_state as mpu, its legacy name
+mpu = parallel_state
+
+__all__ = [
+    "parallel_state",
+    "tensor_parallel",
+    "utils",
+]
@@ -9,7 +9,7 @@ import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter

-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_all_reduce_launcher,
@@ -2,7 +2,7 @@

 import torch

-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
@@ -10,7 +10,7 @@ from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
 from torch.utils.checkpoint import detach_variable

-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.parallel_utils.parallel_state import (
     get_data_parallel_rank,
     get_tensor_model_parallel_group,
     get_tensor_model_parallel_rank,
@@ -22,7 +22,7 @@ from .utils import (
     gather_split_1d_tensor,
 )

-from cacheflow.parallel_utils.utils import safely_set_viewless_tensor_data
+from cacheflow.model_executor.parallel_utils.utils import safely_set_viewless_tensor_data

 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
@@ -3,8 +3,8 @@
 import torch
 from typing import List, Sequence

-from cacheflow.parallel_utils.utils import divide
-from cacheflow.parallel_utils import parallel_state
+from cacheflow.model_executor.parallel_utils.utils import divide
+from cacheflow.model_executor.parallel_utils import parallel_state

 def split_tensor_along_last_dim(
     tensor: torch.Tensor,
@@ -4,7 +4,7 @@ import operator

 import torch

-from cacheflow.parallel_utils import parallel_state
+from cacheflow.model_executor.parallel_utils import parallel_state


 def ensure_divisibility(numerator, denominator):
41 cacheflow/model_executor/utils.py Normal file
@@ -0,0 +1,41 @@
+import random
+from typing import Union
+
+import numpy as np
+import torch
+
+from cacheflow.model_executor.parallel_utils.parallel_state import model_parallel_is_initialized
+from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed
+
+
+_STR_DTYPE_TO_TORCH_DTYPE = {
+    'half': torch.half,
+    'float': torch.float,
+    'float16': torch.float16,
+    'float32': torch.float32,
+    'bfloat16': torch.bfloat16,
+}
+
+
+def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
+    if isinstance(dtype, str):
+        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
+    else:
+        torch_dtype = dtype
+    return torch_dtype
+
+
+def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
+    torch_dtype = get_torch_dtype(dtype)
+    return torch.tensor([], dtype=torch_dtype).element_size()
+
+
+def set_random_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+    if model_parallel_is_initialized():
+        model_parallel_cuda_manual_seed(seed)
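The helpers in this new module are easy to sanity-check in isolation; a small illustrative sketch (assumes the cacheflow package is importable, values follow standard torch dtypes):

    import torch
    from cacheflow.model_executor.utils import get_dtype_size, get_torch_dtype, set_random_seed

    assert get_torch_dtype('half') is torch.float16  # string names map through _STR_DTYPE_TO_TORCH_DTYPE
    assert get_dtype_size('bfloat16') == 2           # element_size() of an empty bf16 tensor
    set_random_seed(0)  # seeds random, numpy and torch; also CUDA and the model-parallel RNG when available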
@@ -1,47 +1,26 @@
-import os
+import filelock
 import glob
 import json
-import filelock
-from typing import Union, Optional
+import os
+from typing import Iterator, List, Optional, Tuple

-from huggingface_hub import snapshot_download
 import numpy as np
 import torch
 from tqdm.auto import tqdm
-from cacheflow.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank)
+from huggingface_hub import snapshot_download


-_STR_DTYPE_TO_TORCH_DTYPE = {
-    'half': torch.half,
-    'float': torch.float,
-    'float16': torch.float16,
-    'float32': torch.float32,
-    'bfloat16': torch.bfloat16,
-}
-
-
-def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
-    if isinstance(dtype, str):
-        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
-    else:
-        torch_dtype = dtype
-    return torch_dtype
-
-
-def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
-    torch_dtype = get_torch_dtype(dtype)
-    return torch.tensor([], dtype=torch_dtype).element_size()
-
-
 class Disabledtqdm(tqdm):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs, disable=True)


-def hf_model_weights_iterator(model_name_or_path: str,
-                              cache_dir: Optional[str] = None,
-                              use_np_cache: bool = False):
+def hf_model_weights_iterator(
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+    use_np_cache: bool = False,
+) -> Iterator[Tuple[str, torch.Tensor]]:
     # Prepare file lock directory to prevent multiple processes from
     # downloading the same model weights at the same time.
     lock_dir = cache_dir if cache_dir is not None else "/tmp"
@@ -95,10 +74,14 @@ def hf_model_weights_iterator(model_name_or_path: str,
             yield name, param


-def load_tensor_parallel_weights(param, loaded_weight, param_name,
-                                 column_parallel_weight_names,
-                                 row_parallel_weight_names):
-    tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+def load_tensor_parallel_weights(
+    param: torch.Tensor,
+    loaded_weight: torch.Tensor,
+    param_name: str,
+    column_parallel_weight_names: List[str],
+    row_parallel_weight_names: List[str],
+    tensor_model_parallel_rank: int,
+) -> None:
     for p in column_parallel_weight_names:
         if p in param_name:
             shard_size = param.shape[0]
@@ -116,3 +99,12 @@ def load_tensor_parallel_weights(param, loaded_weight, param_name,
             break
     assert param.shape == loaded_weight.shape
     param.data.copy_(loaded_weight)
+
+
+def initialize_dummy_weights(
+    model: torch.nn.Module,
+    low: float = -1e-3,
+    high: float = 1e-3,
+) -> None:
+    for param in model.state_dict().values():
+        param.data.uniform_(low, high)
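Two call-site consequences of this file are visible in the model hunks above: the tensor-parallel rank is now passed into load_tensor_parallel_weights by the caller instead of being looked up inside it, and dummy initialization is a free function rather than a per-model method. A hedged sketch of a load_weights path using them (the helper name and the surrounding variables are assumptions, not code from this commit):

    import torch
    from cacheflow.model_executor.parallel_utils.parallel_state import (
        get_tensor_model_parallel_rank)
    from cacheflow.model_executor.weight_utils import (
        hf_model_weights_iterator, initialize_dummy_weights,
        load_tensor_parallel_weights)

    def load_weights_sketch(model: torch.nn.Module, model_name_or_path: str,
                            column_parallel_weights, row_parallel_weights):
        # The caller resolves the TP rank once and threads it through each call.
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = dict(model.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(model_name_or_path):
            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         column_parallel_weights,
                                         row_parallel_weights,
                                         tp_rank)

    # For benchmarking without real checkpoints (see the get_model hunk above):
    #   initialize_dummy_weights(model)  # uniform values in [-1e-3, 1e-3]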
@@ -1,10 +0,0 @@
-from cacheflow.models.input_metadata import InputMetadata
-from cacheflow.models.model_utils import get_memory_analyzer
-from cacheflow.models.model_utils import get_model
-
-
-__all__ = [
-    'InputMetadata',
-    'get_memory_analyzer',
-    'get_model',
-]
@@ -1,12 +0,0 @@
-import cacheflow.parallel_utils.parallel_state
-import cacheflow.parallel_utils.tensor_parallel
-import cacheflow.parallel_utils.utils
-
-# Alias parallel_state as mpu, its legacy name
-mpu = parallel_state
-
-__all__ = [
-    "parallel_state",
-    "tensor_parallel",
-    "utils",
-]
@@ -1,13 +1,8 @@
 import enum
-import random

 import psutil

-import numpy as np
 import torch
-
-from cacheflow.parallel_utils.parallel_state import model_parallel_is_initialized
-from cacheflow.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed


 class Device(enum.Enum):
     GPU = enum.auto()
@@ -28,17 +23,6 @@ class Counter:
         self.counter = 0


-def set_random_seed(seed: int):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
-
-    if model_parallel_is_initialized():
-        model_parallel_cuda_manual_seed(seed)
-
-
 def get_gpu_memory(gpu: int = 0) -> int:
     return torch.cuda.get_device_properties(gpu).total_memory

@@ -5,7 +5,7 @@ try:
 except ImportError:
     ray = None

-from cacheflow.master.scheduler import Scheduler
+from cacheflow.core.scheduler import Scheduler
 from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.worker.worker import Worker

@@ -2,18 +2,15 @@ from typing import Dict, List, Tuple, Optional

 import torch

-from cacheflow.models import get_model
-from cacheflow.models import InputMetadata
+from cacheflow.model_executor import get_model, InputMetadata, set_random_seed
+from cacheflow.model_executor.parallel_utils.parallel_state import (
+    initialize_model_parallel,
+    initialize_all_reduce_launcher,
+    get_tensor_model_parallel_world_size)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine
-from cacheflow.parallel_utils.parallel_state import (
-    initialize_model_parallel,
-    initialize_all_reduce_launcher,
-    get_tensor_model_parallel_world_size)
-from cacheflow.utils import set_random_seed


 class Worker:

@@ -1,212 +0,0 @@
-import argparse
-import os
-import pickle
-from typing import Any, Dict, List, Optional, Tuple
-
-import matplotlib.pyplot as plt
-import numpy as np
-
-
-SYSTEMS = [
-    'orca-constant',
-    'orca-power2',
-    'orca-oracle',
-    'cacheflow',
-]
-
-SYSTEM_TO_LABEL = {
-    'orca-constant': 'Orca (Max)',
-    'orca-power2': 'Orca (Pow2)',
-    'orca-oracle': 'Orca (Oracle)',
-    'cacheflow': 'KVFlow',
-}
-
-SYSTEM_TO_COLOR = {
-    'orca-constant': 'red',
-    'orca-power2': 'orange',
-    'orca-oracle': 'green',
-    'cacheflow': 'blue',
-}
-
-SYSTEM_TO_MARKER = {
-    'orca-constant': 'x',
-    'orca-power2': '^',
-    'orca-oracle': 's',
-    'cacheflow': 'o',
-}
-
-
-def get_results(save_dir: str) -> List[Dict[str, Any]]:
-    with open(os.path.join(save_dir, 'sequences.pkl'), 'rb') as f:
-        results = pickle.load(f)
-    return results
-
-
-def get_request_rate(save_dir: str) -> float:
-    """Get request rate from save_dir name."""
-    # Directory name format:
-    # .../req-rate-{req_rate}/seed-{seed}/duration-{duration}
-    save_dir = os.path.abspath(save_dir)
-    dir_names = save_dir.split('/')
-
-    request_rate = None
-    for dir_name in dir_names:
-        if dir_name.startswith('req-rate-'):
-            if request_rate is not None:
-                raise ValueError(f'Found multiple request rates in {save_dir}')
-            request_rate = float(dir_name.split('-')[-1])
-    if request_rate is None:
-        raise ValueError(f'Cannot find request rate in {save_dir}')
-    return request_rate
-
-
-def get_model(save_dir: str) -> Tuple[str, int]:
-    save_dir = os.path.abspath(save_dir)
-    dir_names = save_dir.split('/')
-
-    model = None
-    for dir_name in dir_names:
-        if '-tp' in dir_name:
-            if model is not None:
-                raise ValueError(f'Found multiple models in {save_dir}')
-            model = dir_name.split('-tp')[0]
-            tp = int(dir_name.split('-tp')[-1])
-    if model is None:
-        raise ValueError(f'Cannot find model in {save_dir}')
-    return model, tp
-
-
-def get_system(save_dir: str) -> str:
-    save_dir = os.path.abspath(save_dir)
-    dir_names = save_dir.split('/')
-
-    for dir_name in dir_names:
-        if dir_name.startswith('orca-'):
-            return dir_name
-        if dir_name == 'cacheflow':
-            return dir_name
-    raise ValueError(f'Cannot find system in {save_dir}')
-
-
-def get_sampling(save_dir: str) -> str:
-    save_dir = os.path.abspath(save_dir)
-    dir_names = save_dir.split('/')
-
-    for dir_name in dir_names:
-        if dir_name.startswith('n'):
-            if dir_name.endswith('-beam'):
-                return dir_name
-            if dir_name[1:].isdigit():
-                return dir_name
-    raise ValueError(f'Cannot find sampling method in {save_dir}')
-
-
-def plot_normalized_latency(
-    exp_dir: str,
-    duration: int,
-    seed: int,
-    warmup: int,
-    xlim: Optional[float],
-    ylim: Optional[float],
-    log_scale: bool,
-    format: str,
-) -> None:
-    # Get leaf directories.
-    save_dirs = []
-    for root, dirs, files in os.walk(exp_dir):
-        if dirs:
-            continue
-        if 'sequences.pkl' not in files:
-            continue
-        if f'seed{seed}' not in root:
-            continue
-        if f'duration-{duration}' not in root:
-            continue
-        save_dirs.append(root)
-
-    # Plot normalized latency.
-    perf_per_system: Dict[str, Tuple[List[float], List[float]]] = {}
-    for save_dir in save_dirs:
-        per_seq_norm_latencies = []
-        results = get_results(save_dir)
-        for seq in results:
-            arrival_time = seq['arrival_time']
-            finish_time = seq['finish_time']
-            output_len = seq['output_len']
-            if arrival_time < warmup:
-                continue
-            latency = finish_time - arrival_time
-            norm_latency = latency / output_len
-            per_seq_norm_latencies.append(norm_latency)
-
-        request_rate = get_request_rate(save_dir)
-        normalized_latency = np.mean(per_seq_norm_latencies)
-        system_name = get_system(save_dir)
-        if system_name not in perf_per_system:
-            perf_per_system[system_name] = ([], [])
-        perf_per_system[system_name][0].append(request_rate)
-        perf_per_system[system_name][1].append(normalized_latency)
-
-        print('#seqs', len(per_seq_norm_latencies))
-        print(f'{save_dir}: {normalized_latency:.3f} s')
-
-
-    # Plot normalized latency.
-    plt.figure(figsize=(6, 4))
-    for system_name in reversed(SYSTEMS):
-        if system_name not in perf_per_system:
-            continue
-        # Sort by request rate.
-        request_rates, normalized_latencies = perf_per_system[system_name]
-        request_rates, normalized_latencies = zip(*sorted(zip(request_rates, normalized_latencies)))
-        label = SYSTEM_TO_LABEL[system_name]
-        color = SYSTEM_TO_COLOR[system_name]
-        marker = SYSTEM_TO_MARKER[system_name]
-        plt.plot(request_rates, normalized_latencies, label=label, color=color, marker=marker)
-
-    # plt.legend()
-    plt.xlabel('Request rate (req/s)', fontsize=12)
-    plt.ylabel('Normalized latency (s/token)', fontsize=12)
-
-    if log_scale:
-        plt.yscale('log')
-    if xlim is not None:
-        plt.xlim(left=0, right=xlim)
-    if ylim is not None:
-        if log_scale:
-            plt.ylim(top=ylim)
-        else:
-            plt.ylim(bottom=0, top=ylim)
-
-    handles, labels = plt.gca().get_legend_handles_labels()
-    handles = reversed(handles)
-    labels = reversed(labels)
-
-    plt.legend(
-        handles, labels,
-        ncol=4, fontsize=12, loc='upper center', bbox_to_anchor=(0.5, 1.15),
-        columnspacing=0.5, handletextpad=0.5, handlelength=1.5, frameon=False, borderpad=0)
-
-    # Save figure.
-    model, tp = get_model(exp_dir)
-    sampling = get_sampling(exp_dir)
-    figname = f'{model}-tp{tp}-{sampling}.{format}'
-    os.makedirs('./figures', exist_ok=True)
-    plt.savefig(os.path.join('figures', figname), bbox_inches='tight')
-    print(f'Saved figure to ./figures/{figname}')
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('exp_dir', type=str)
-    parser.add_argument('--duration', type=int, required=True)
-    parser.add_argument('--seed', type=int, default=0)
-    parser.add_argument('--warmup', type=int, default=60)
-    parser.add_argument('--xlim', type=float, required=False, default=None)
-    parser.add_argument('--ylim', type=float, required=False, default=None)
-    parser.add_argument('--log', action='store_true')
-    parser.add_argument('--format', choices=['png', 'pdf'], default='png')
-    args = parser.parse_args()
-
-    plot_normalized_latency(
-        args.exp_dir, args.duration, args.seed, args.warmup, args.xlim, args.ylim, args.log, args.format)
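The metric this deleted script plots is per-sequence normalized latency: end-to-end latency divided by the number of generated tokens, then averaged over sequences. A self-contained toy sketch of that computation (made-up numbers, no dependency on the removed benchmark layout):

    seqs = [
        {'arrival_time': 0.0, 'finish_time': 4.0, 'output_len': 16},
        {'arrival_time': 1.0, 'finish_time': 9.0, 'output_len': 32},
    ]
    per_seq = [(s['finish_time'] - s['arrival_time']) / s['output_len'] for s in seqs]
    print(sum(per_seq) / len(per_seq))  # 0.25 s/token for both sequences -> mean 0.25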
@@ -1,52 +0,0 @@
-import os
-import pickle
-
-import matplotlib.pyplot as plt
-
-STAT_NAMES = [
-    'input_lens',
-    'num_running',
-    'num_waiting',
-    'num_preemption',
-    'gpu_cache_usage',
-    'cpu_cache_usage',
-    'num_swapped',
-    'swap_in_lens',
-    'swap_out_lens',
-]
-
-
-def plot_stats(output_dir: str):
-    # Get stats.
-    with open(os.path.join(output_dir, 'stats.pkl'), 'rb') as f:
-        stats = pickle.load(f)
-    timestamps = stats['timestamps']
-
-    # Draw one figure for each stat.
-    num_stats = len(STAT_NAMES)
-    COLORS = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'pink', 'brown', 'gray']
-    fig, axs = plt.subplots(num_stats, 1, figsize=(10, 2 * num_stats))
-    for i, stat in enumerate(STAT_NAMES):
-        data = stats[stat]
-        if stat in ['gpu_cache_usage', 'cpu_cache_usage']:
-            data = [x * 100 for x in data]
-            stat = stat + ' (%)'
-        axs[i].plot(timestamps, data, color=COLORS[i % len(COLORS)])
-        axs[i].set_ylabel(stat.replace('_', ' '), fontdict={'fontsize': 12})
-        axs[i].set_ylim(bottom=0)
-
-    plt.xlabel('Time (s)')
-    plt.tight_layout()
-    fig_path = os.path.join(output_dir, 'stats.png')
-    plt.savefig(fig_path)
-    print(f'Saved stats to {fig_path}')
-
-
-if __name__ == '__main__':
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('output_dir', type=str, help='Output directory.')
-    args = parser.parse_args()
-
-    plot_stats(args.output_dir)
@@ -1,6 +1,6 @@
 import argparse

-from cacheflow.master.server import (
+from cacheflow.core.server import (
     add_server_arguments, process_server_arguments,
     init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams