2023-05-20 13:06:59 -07:00
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
import torch
|
2023-07-03 16:47:53 -07:00
|
|
|
from transformers import PretrainedConfig
|
2023-05-20 13:06:59 -07:00
|
|
|
|
2023-06-17 03:07:40 -07:00
|
|
|
from vllm.logger import init_logger
|
2023-07-03 16:47:53 -07:00
|
|
|
from vllm.transformers_utils.config import get_config
|
2023-06-17 03:07:40 -07:00
|
|
|
from vllm.utils import get_cpu_memory
|
2023-05-23 18:22:26 -07:00
|
|
|
|
|
|
|
# Module-level logger; used below to emit configuration warnings.
logger = init_logger(__name__)
|
|
|
|
|
2023-07-03 11:31:55 -07:00
|
|
|
_GB = 1 << 30
|
2023-05-21 17:04:18 -07:00
|
|
|
|
2023-05-20 13:06:59 -07:00
|
|
|
|
|
|
|
class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights, default to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16
            precision for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        dtype: str,
        seed: int,
        revision: Optional[str],
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.quantization = quantization

        self.hf_config = get_config(model, trust_remote_code, revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()
        # Keep this None while deriving: get_max_model_len() returns the
        # stored value when it is set, and we want the HF-config-derived one.
        self.max_model_len = None
        if max_model_len is not None:
            derived_max_model_len = self.get_max_model_len()
            if max_model_len > derived_max_model_len:
                logger.warning(
                    f"User-specified max_model_len ({max_model_len}) is "
                    f"greater than the derived max_model_len "
                    f"({derived_max_model_len}). Make sure the value is "
                    "correct and within the model context size.")
            self.max_model_len = max_model_len

    def _verify_load_format(self) -> None:
        """Lower-case ``self.load_format`` and check it is a known format.

        Raises:
            ValueError: If the format is not one of "auto", "pt",
                "safetensors", "npcache" or "dummy".
        """
        load_format = self.load_format.lower()
        if load_format not in [
                "auto", "pt", "safetensors", "npcache", "dummy"
        ]:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        """Lower-case ``self.tokenizer_mode`` and check it is supported.

        Raises:
            ValueError: If the mode is neither "auto" nor "slow".
        """
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        """Lower-case ``self.quantization`` (if set) and check it is supported.

        A value of None means "not quantized" and is always accepted.

        Raises:
            ValueError: If the method is not in the supported list.
        """
        supported_quantization = ["awq"]
        if self.quantization is None:
            return
        quantization = self.quantization.lower()
        if quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization: {self.quantization}. Must be one of "
                f"{supported_quantization}.")
        self.quantization = quantization

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the model can be evenly split by ``parallel_config``.

        Raises:
            ValueError: If the attention-head count is not divisible by the
                tensor parallel size, or the hidden-layer count is not
                divisible by the pipeline parallel size.
        """
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_hidden_size(self) -> int:
        """Return the model's hidden size as read from the HF config."""
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        """Return the per-head dimension (hidden_size // num_attention_heads)."""
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of KV heads held by each tensor-parallel rank.

        Resolution order:
          1. Multi-query models (``multi_query`` set, excluding new-decoder
             Falcon) have exactly one KV head; this is not divided by the
             tensor parallel size.
          2. ``n_head_kv`` (Falcon / RefinedWeb-style configs).
          3. ``num_kv_heads`` for new-decoder-architecture Falcon models
             whose config does not provide ``n_head_kv``.
          4. ``num_key_value_heads`` (LLaMA-2 style grouped-query attention).
          5. Otherwise every attention head has its own KV head.
        Cases 2-5 divide the total by ``tensor_parallel_size``.
        """
        # For GPTBigCode & Falcon:
        # Note: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv (or num_kv_heads)
        # for the number of KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            return 1
        tensor_parallel_size = parallel_config.tensor_parallel_size
        # For Falcon (RefinedWeb-style config):
        n_head_kv = getattr(self.hf_config, "n_head_kv", None)
        if n_head_kv is not None:
            return n_head_kv // tensor_parallel_size
        # For Falcon with the new decoder architecture loaded through
        # transformers' FalconConfig, the KV-head count is stored as
        # `num_kv_heads`; without this check such models would fall through
        # to num_attention_heads and over-count the KV heads.
        if new_decoder_arch_falcon:
            num_kv_heads = getattr(self.hf_config, "num_kv_heads", None)
            if num_kv_heads is not None:
                return num_kv_heads // tensor_parallel_size
        # For LLaMA-2:
        num_key_value_heads = getattr(self.hf_config, "num_key_value_heads",
                                      None)
        if num_key_value_heads is not None:
            return num_key_value_heads // tensor_parallel_size
        total_num_attention_heads = self.hf_config.num_attention_heads
        return total_num_attention_heads // tensor_parallel_size

    def get_max_model_len(self) -> int:
        """Return the maximum sequence length (prompt + output).

        If ``self.max_model_len`` is set it is returned directly. Otherwise
        the smallest value among the known length keys in the HF config is
        used. NOTE: when none of those keys is present this returns
        ``float("inf")``, not an int.
        """
        if self.max_model_len is not None:
            return self.max_model_len
        max_model_len = float("inf")
        possible_keys = [
            # OPT
            "max_position_embeddings",
            # GPT-2
            "n_positions",
            # MPT
            "max_seq_len",
            # Others
            "max_sequence_length",
            "max_seq_length",
            "seq_len",
        ]
        for key in possible_keys:
            max_len_key = getattr(self.hf_config, key, None)
            if max_len_key is not None:
                # Several keys may be present; the tightest bound wins.
                max_model_len = min(max_model_len, max_len_key)
        return max_model_len

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of hidden layers per pipeline-parallel stage."""
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size
|
|
|
|
|
|
|
|
|
|
|
|
class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution. Must be greater than 0 and at most 1.
        swap_space: Size of the CPU swap space per GPU (in GiB).
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self._verify_args()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def _verify_args(self) -> None:
        """Validate the constructor arguments that need no external context.

        Raises:
            ValueError: If ``gpu_memory_utilization`` is greater than 1.0 or
                not positive.
        """
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")
        # A zero or negative fraction of GPU memory is not a meaningful
        # setting; reject it up front rather than accepting it silently.
        if self.gpu_memory_utilization <= 0.0:
            raise ValueError(
                "GPU memory utilization must be greater than 0.0. Got "
                f"{self.gpu_memory_utilization}.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check the total swap space against the host's CPU memory.

        Total swap is ``swap_space_bytes`` times the tensor parallel size.
        Raises ValueError above 70% of total CPU memory and logs a warning
        above 40%.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)
|
2023-05-23 18:22:26 -07:00
|
|
|
|
2023-05-20 13:06:59 -07:00
|
|
|
|
|
|
|
class ParallelConfig:
    """Settings that describe how execution is distributed across GPUs.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether model workers run under Ray. Forced to True
            whenever the product of the two parallel sizes exceeds 1.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
    ) -> None:
        total_workers = pipeline_parallel_size * tensor_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        # Any multi-GPU layout requires Ray workers, regardless of the flag.
        self.worker_use_ray = True if total_workers > 1 else worker_use_ray
        self.world_size = total_workers
        self._verify_args()

    def _verify_args(self) -> None:
        """Reject parallel layouts that are not implemented.

        Raises:
            NotImplementedError: If pipeline parallelism (size > 1) is asked for.
        """
        if self.pipeline_parallel_size <= 1:
            return
        raise NotImplementedError(
            "Pipeline parallelism is not supported yet.")
|
|
|
|
|
|
|
|
|
|
|
|
class SchedulerConfig:
    """Limits the scheduler enforces on each iteration.

    Args:
        max_num_batched_tokens: Upper bound on tokens processed in one
            iteration.
        max_num_seqs: Upper bound on sequences processed in one iteration.
        max_model_len: Longest allowed sequence, counting both the prompt
            and the generated text.
    """

    def __init__(
        self,
        max_num_batched_tokens: int,
        max_num_seqs: int,
        max_model_len: int,
    ) -> None:
        # Plain storage; no validation is performed here.
        (self.max_num_batched_tokens, self.max_num_seqs,
         self.max_model_len) = (max_num_batched_tokens, max_num_seqs,
                                max_model_len)
|
2023-05-20 13:06:59 -07:00
|
|
|
|
|
|
|
|
|
|
|
_STR_DTYPE_TO_TORCH_DTYPE = {
|
|
|
|
"half": torch.float16,
|
|
|
|
"float16": torch.float16,
|
|
|
|
"float": torch.float32,
|
|
|
|
"float32": torch.float32,
|
|
|
|
"bfloat16": torch.bfloat16,
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: str,
) -> torch.dtype:
    """Turn a dtype string into a ``torch.dtype`` and sanity-check it.

    ``dtype`` is matched case-insensitively. "auto" picks the model's own
    dtype, except that float32 models are run in float16. Any other value
    must be a key of ``_STR_DTYPE_TO_TORCH_DTYPE``.

    Casting to or from float32 is silent; any other mismatch with the
    model's dtype (e.g. float16 <-> bfloat16) logs a warning. bfloat16 is
    additionally rejected on GPUs whose compute capability major version is
    below 8.

    Raises:
        ValueError: For an unknown dtype name, or bfloat16 on an
            unsupported GPU.
    """
    # A getattr default is not enough: config.torch_dtype can exist and be
    # None. Either way, treat the model as float32.
    model_dtype = getattr(config, "torch_dtype", None)
    if model_dtype is None:
        model_dtype = torch.float32

    requested = dtype.lower()
    if requested != "auto":
        if requested not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {requested}")
        resolved = _STR_DTYPE_TO_TORCH_DTYPE[requested]
    elif model_dtype == torch.float32:
        # Common practice: serve float32 checkpoints in float16.
        resolved = torch.float16
    else:
        resolved = model_dtype

    # Up- or down-casting through float32 is fine; only warn when neither
    # side is float32.
    if (resolved != model_dtype and resolved != torch.float32
            and model_dtype != torch.float32):
        logger.warning(f"Casting {model_dtype} to {resolved}.")

    if resolved == torch.bfloat16:
        major, minor = torch.cuda.get_device_capability()
        if major < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{major}.{minor}.")
    return resolved
|