diff --git a/vllm/config.py b/vllm/config.py
index 567bb44b..ca260f27 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -38,6 +38,8 @@ class ModelConfig:
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
         seed: Random seed for reproducibility.
+        max_model_len: Maximum length of a sequence (including prompt and
+            output). If None, will be derived from the model.
     """

     def __init__(
@@ -50,6 +52,7 @@ class ModelConfig:
         load_format: str,
         dtype: str,
         seed: int,
+        max_model_len: Optional[int] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -63,6 +66,16 @@ class ModelConfig:
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
+        self.max_model_len = None
+        if max_model_len is not None:
+            derived_max_model_len = self.get_max_model_len()
+            if max_model_len > derived_max_model_len:
+                logger.warning(
+                    f"User-specified max_model_len ({max_model_len}) is "
+                    f"greater than the derived max_model_len "
+                    f"({derived_max_model_len}). Make sure the value is "
+                    "correct and within the model context size.")
+            self.max_model_len = max_model_len

     def _verify_load_format(self) -> None:
         load_format = self.load_format.lower()
@@ -134,6 +147,8 @@ class ModelConfig:
         return total_num_attention_heads // parallel_config.tensor_parallel_size

     def get_max_model_len(self) -> int:
+        if self.max_model_len is not None:
+            return self.max_model_len
         max_model_len = float("inf")
         possible_keys = [
             # OPT
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 7679966c..1e3f9e64 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -18,6 +18,7 @@ class EngineArgs:
     load_format: str = 'auto'
     dtype: str = 'auto'
     seed: int = 0
+    max_model_len: Optional[int] = None
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
@@ -89,6 +90,11 @@ class EngineArgs:
                             'The "auto" option will use FP16 precision '
                             'for FP32 and FP16 models, and BF16 precision '
                             'for BF16 models.')
+        parser.add_argument('--max-model-len',
+                            type=int,
+                            default=None,
+                            help='model context length. If unspecified, '
+                            'will be automatically derived from the model.')
         # Parallel arguments
         parser.add_argument('--worker-use-ray',
                             action='store_true',
@@ -153,7 +159,7 @@ class EngineArgs:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
-                                   self.dtype, self.seed)
+                                   self.dtype, self.seed, self.max_model_len)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization, self.swap_space)
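
For context, a minimal sketch (not part of this diff) of how the new argument is consumed. The model name and the value 1024 are placeholders, and the keyword names are assumed to mirror the positional order used in `EngineArgs.create_engine_configs` above.

```python
# Sketch only: constructing ModelConfig directly with a user-specified
# max_model_len. "facebook/opt-125m" and 1024 are placeholder values.
from vllm.config import ModelConfig

model_config = ModelConfig(
    model="facebook/opt-125m",
    tokenizer="facebook/opt-125m",
    tokenizer_mode="auto",
    trust_remote_code=False,
    download_dir=None,
    load_format="auto",
    dtype="auto",
    seed=0,
    max_model_len=1024,  # cap on prompt + generated tokens
)

# With max_model_len set, get_max_model_len() returns the user value instead of
# the length derived from the HuggingFace config (e.g. max_position_embeddings
# for OPT). A value above the derived maximum only triggers a warning; it is
# not clamped.
assert model_config.get_max_model_len() == 1024
```

On the command line, the same behavior is exposed through the new `--max-model-len` flag added to `EngineArgs.add_cli_args`; leaving it unset keeps the previous behavior of deriving the context length from the model.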