Make max_model_len configurable (#972)

Antoni Baum 2023-09-12 16:29:19 -07:00 committed by GitHub
parent d6545ad22e
commit 0bb1e885a0
2 changed files with 22 additions and 1 deletion

vllm/config.py

@@ -38,6 +38,8 @@ class ModelConfig:
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
         seed: Random seed for reproducibility.
+        max_model_len: Maximum length of a sequence (including prompt and
+            output). If None, will be derived from the model.
     """
 
     def __init__(
@@ -50,6 +52,7 @@ class ModelConfig:
         load_format: str,
         dtype: str,
         seed: int,
+        max_model_len: Optional[int] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -63,6 +66,16 @@ class ModelConfig:
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
+        self.max_model_len = None
+        if max_model_len is not None:
+            derived_max_model_len = self.get_max_model_len()
+            if max_model_len > derived_max_model_len:
+                logger.warning(
+                    f"User-specified max_model_len ({max_model_len}) is "
+                    f"greater than the derived max_model_len "
+                    f"({derived_max_model_len}). Make sure the value is "
+                    "correct and within the model context size.")
+            self.max_model_len = max_model_len
 
     def _verify_load_format(self) -> None:
         load_format = self.load_format.lower()
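Note that the check above only warns rather than raising, so an over-long user value is still stored and takes effect. A standalone sketch of that resolution logic (resolve_max_model_len is a hypothetical name for illustration; Python's warnings module stands in for vLLM's logger):

import warnings
from typing import Optional

def resolve_max_model_len(user_value: Optional[int],
                          derived_value: int) -> int:
    # Mirrors the diff: prefer the user's value, but warn when it
    # exceeds the length derived from the model's config.
    if user_value is None:
        return derived_value
    if user_value > derived_value:
        warnings.warn(
            f"User-specified max_model_len ({user_value}) is greater "
            f"than the derived max_model_len ({derived_value}).")
    return user_value

# A 2048-token model with a 4096 override: warns, then returns 4096.
print(resolve_max_model_len(4096, 2048))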
@@ -134,6 +147,8 @@ class ModelConfig:
         return total_num_attention_heads // parallel_config.tensor_parallel_size
 
     def get_max_model_len(self) -> int:
+        if self.max_model_len is not None:
+            return self.max_model_len
         max_model_len = float("inf")
         possible_keys = [
             # OPT

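The derivation in get_max_model_len is truncated above at the possible_keys list; the pattern is to scan the HuggingFace config for known context-length attributes and take the smallest one found. A sketch under that assumption (derive_max_model_len is a hypothetical name, and the key list is an illustrative subset rather than vLLM's full one):

from types import SimpleNamespace

# Illustrative subset; the diff's real list starts with OPT's key.
POSSIBLE_KEYS = [
    "max_position_embeddings",  # OPT (and many other architectures)
    "n_positions",              # GPT-2
    "seq_length",               # some other architectures
]

def derive_max_model_len(hf_config) -> int:
    # Take the minimum over whichever context-length keys the config defines.
    max_model_len = float("inf")
    for key in POSSIBLE_KEYS:
        value = getattr(hf_config, key, None)
        if value is not None:
            max_model_len = min(max_model_len, value)
    return int(max_model_len)  # assumes at least one key was present

# e.g. an OPT-style config declaring a 2048-token context:
print(derive_max_model_len(SimpleNamespace(max_position_embeddings=2048)))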
vllm/engine/arg_utils.py

@@ -18,6 +18,7 @@ class EngineArgs:
     load_format: str = 'auto'
     dtype: str = 'auto'
     seed: int = 0
+    max_model_len: Optional[int] = None
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
@@ -89,6 +90,11 @@ class EngineArgs:
                             'The "auto" option will use FP16 precision '
                             'for FP32 and FP16 models, and BF16 precision '
                             'for BF16 models.')
+        parser.add_argument('--max-model-len',
+                            type=int,
+                            default=None,
+                            help='model context length. If unspecified, '
+                            'will be automatically derived from the model.')
         # Parallel arguments
         parser.add_argument('--worker-use-ray',
                             action='store_true',
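Since any entrypoint built on EngineArgs.add_cli_args picks the new flag up automatically, it can be exercised directly from Python. A sketch, assuming EngineArgs also exposes its from_cli_args helper for turning parsed arguments back into an EngineArgs instance:

import argparse
from vllm.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# The same flags an API server process would receive on its command line.
args = parser.parse_args(
    ['--model', 'facebook/opt-125m', '--max-model-len', '512'])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.max_model_len)  # 512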
@@ -153,7 +159,7 @@ class EngineArgs:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
-                                   self.dtype, self.seed)
+                                   self.dtype, self.seed, self.max_model_len)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
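End to end, the value flows from the CLI (or an EngineArgs instance) through create_engine_configs into ModelConfig. Since the LLM entrypoint forwards extra keyword arguments to EngineArgs, capping the context window from Python should look roughly like this (a sketch, assuming that forwarding; the model choice is only an example):

from vllm import LLM, SamplingParams

# Cap the context window below OPT-125m's native 2048 tokens; a value
# above the derived length would only trigger the warning seen above.
llm = LLM(model="facebook/opt-125m", max_model_len=512)
outputs = llm.generate(["The capital of France is"],
                       sampling_params=SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)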