from typing import Optional

import torch
from transformers import AutoConfig, PretrainedConfig


class ModelConfig:
    """Configuration for the model to be served.

    Holds the Hugging Face model name, weight-loading options, data type,
    and random seed, and exposes helpers that account for parallelism.
    """

    def __init__(
        self,
        model: str,
        download_dir: Optional[str],
        use_np_weights: bool,
        use_dummy_weights: bool,
        dtype: str,
        seed: int,
    ) -> None:
        self.model = model
        self.download_dir = download_dir
        self.use_np_weights = use_np_weights
        self.use_dummy_weights = use_dummy_weights
        self.seed = seed

        # Load the Hugging Face config and resolve the torch dtype.
        self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        # Attention heads are partitioned across tensor-parallel ranks.
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        # Hidden layers are partitioned across pipeline-parallel stages.
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_hidden_size(self) -> int:
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
        # Number of attention heads per tensor-parallel rank.
        total_num_attention_heads = self.hf_config.num_attention_heads
        return total_num_attention_heads // parallel_config.tensor_parallel_size

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        # Number of hidden layers per pipeline-parallel stage.
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:
    """Configuration for the KV cache (block size, GPU memory budget, swap space)."""

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space = swap_space

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None


class ParallelConfig:
    """Configuration for pipeline and tensor parallelism."""

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        use_ray: bool,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.use_ray = use_ray

        self.world_size = pipeline_parallel_size * tensor_parallel_size
        # Multi-GPU execution requires Ray, regardless of the user's setting.
        if self.world_size > 1:
            self.use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")


class SchedulerConfig:
    """Configuration for the scheduler (per-iteration batching limits)."""

    def __init__(
        self,
        max_num_batched_tokens: int,
        max_num_seqs: int,
    ) -> None:
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: str,
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    dtype = dtype.lower()
    if dtype == "default":
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    else:
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            pass
        else:
            # Casting between float16 and bfloat16 is not allowed.
            raise ValueError(
                f"Cannot use {torch_dtype} for {config_dtype} model.")

    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}.")
    return torch_dtype
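

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how these config objects might be wired
# together. The model name "facebook/opt-125m" and all numeric values below
# are illustrative assumptions, not prescribed defaults; constructing
# ModelConfig downloads the Hugging Face config, so it needs network access
# or a local HF cache.
if __name__ == "__main__":
    parallel_config = ParallelConfig(
        pipeline_parallel_size=1,  # Values > 1 raise NotImplementedError above.
        tensor_parallel_size=1,
        use_ray=False,  # Forced to True automatically when world_size > 1.
    )
    cache_config = CacheConfig(
        block_size=16,               # Tokens per KV-cache block (assumed value).
        gpu_memory_utilization=0.9,  # Fraction of GPU memory to use (assumed).
        swap_space=4,                # CPU swap space; units are defined by the caller.
    )
    scheduler_config = SchedulerConfig(
        max_num_batched_tokens=2560,  # Assumed batching limits.
        max_num_seqs=256,
    )
    model_config = ModelConfig(
        model="facebook/opt-125m",  # Illustrative model; any HF model name works.
        download_dir=None,
        use_np_weights=False,
        use_dummy_weights=False,
        dtype="default",  # Use the HF config dtype (float16 if the config says float32).
        seed=0,
    )
    # Check that the model can be sharded with the chosen parallelism.
    model_config.verify_with_parallel_config(parallel_config)
    print("dtype:", model_config.dtype)
    print("heads per TP rank:", model_config.get_num_heads(parallel_config))
    print("layers per PP stage:", model_config.get_num_layers(parallel_config))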