Make max_model_len configurable (#972)

Antoni Baum 2023-09-12 16:29:19 -07:00 committed by GitHub
parent d6545ad22e
commit 0bb1e885a0
2 changed files with 22 additions and 1 deletion

vllm/config.py

@@ -38,6 +38,8 @@ class ModelConfig:
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
         seed: Random seed for reproducibility.
+        max_model_len: Maximum length of a sequence (including prompt and
+            output). If None, will be derived from the model.
     """

     def __init__(
@@ -50,6 +52,7 @@ class ModelConfig:
         load_format: str,
         dtype: str,
         seed: int,
+        max_model_len: Optional[int] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -63,6 +66,16 @@ class ModelConfig:
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
+        self.max_model_len = None
+        if max_model_len is not None:
+            derived_max_model_len = self.get_max_model_len()
+            if max_model_len > derived_max_model_len:
+                logger.warning(
+                    f"User-specified max_model_len ({max_model_len}) is "
+                    f"greater than the derived max_model_len "
+                    f"({derived_max_model_len}). Make sure the value is "
+                    "correct and within the model context size.")
+            self.max_model_len = max_model_len

     def _verify_load_format(self) -> None:
         load_format = self.load_format.lower()
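The block above first sets self.max_model_len to None and only then overrides it with the caller's value, logging a warning (rather than raising) when that value exceeds the length derived from the model's config. A minimal standalone sketch of that precedence, using a hypothetical resolve_max_model_len helper with the derived length passed in directly instead of the real ModelConfig plumbing:

    import logging
    from typing import Optional

    logger = logging.getLogger(__name__)

    def resolve_max_model_len(user_value: Optional[int],
                              derived_value: int) -> Optional[int]:
        # Mirrors the __init__ logic above: a user-specified value always
        # wins, but a warning is logged when it exceeds the derived length;
        # None means "derive the length later from the model config".
        if user_value is None:
            return None
        if user_value > derived_value:
            logger.warning(
                f"User-specified max_model_len ({user_value}) is greater "
                f"than the derived max_model_len ({derived_value}).")
        return user_value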
@@ -134,6 +147,8 @@ class ModelConfig:
         return total_num_attention_heads // parallel_config.tensor_parallel_size

     def get_max_model_len(self) -> int:
+        if self.max_model_len is not None:
+            return self.max_model_len
         max_model_len = float("inf")
         possible_keys = [
             # OPT
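With the early return above, get_max_model_len() skips scanning the HuggingFace config for keys such as max_position_embeddings whenever a user value was supplied. A hedged usage sketch, assuming ModelConfig is importable from vllm.config and using facebook/opt-125m as an arbitrary example (fetching its config requires network access):

    from vllm.config import ModelConfig  # assumed import path

    config = ModelConfig(model="facebook/opt-125m",
                         tokenizer="facebook/opt-125m",
                         tokenizer_mode="auto",
                         trust_remote_code=False,
                         download_dir=None,
                         load_format="auto",
                         dtype="auto",
                         seed=0,
                         max_model_len=512)
    # The user-specified cap is returned directly; without max_model_len the
    # value would instead be derived from the model config.
    assert config.get_max_model_len() == 512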

vllm/engine/arg_utils.py

@@ -18,6 +18,7 @@ class EngineArgs:
     load_format: str = 'auto'
     dtype: str = 'auto'
     seed: int = 0
+    max_model_len: Optional[int] = None
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
@@ -89,6 +90,11 @@ class EngineArgs:
                             'The "auto" option will use FP16 precision '
                             'for FP32 and FP16 models, and BF16 precision '
                             'for BF16 models.')
+        parser.add_argument('--max-model-len',
+                            type=int,
+                            default=None,
+                            help='model context length. If unspecified, '
+                            'will be automatically derived from the model.')
         # Parallel arguments
         parser.add_argument('--worker-use-ray',
                             action='store_true',
@@ -153,7 +159,7 @@ class EngineArgs:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
-                                   self.dtype, self.seed)
+                                   self.dtype, self.seed, self.max_model_len)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
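End to end, the value flows from the EngineArgs dataclass (or the new --max-model-len flag) into ModelConfig through create_engine_configs(). A hedged usage sketch, assuming EngineArgs is importable from vllm.engine.arg_utils and that ModelConfig is the first element of the returned config tuple:

    from vllm.engine.arg_utils import EngineArgs  # assumed import path

    engine_args = EngineArgs(model="facebook/opt-125m", max_model_len=1024)
    model_config = engine_args.create_engine_configs()[0]  # ModelConfig comes first
    print(model_config.get_max_model_len())  # 1024: the user cap, not the derived value

On the command line, the same effect comes from passing --max-model-len 1024 to any entrypoint that builds its parser with EngineArgs.add_cli_args().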