Make max_model_len configurable (#972)
This commit is contained in:
parent d6545ad22e
commit 0bb1e885a0
@@ -38,6 +38,8 @@ class ModelConfig:
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
         seed: Random seed for reproducibility.
+        max_model_len: Maximum length of a sequence (including prompt and
+            output). If None, will be derived from the model.
     """

     def __init__(
@@ -50,6 +52,7 @@ class ModelConfig:
         load_format: str,
         dtype: str,
         seed: int,
+        max_model_len: Optional[int] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -63,6 +66,16 @@ class ModelConfig:
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
+        self.max_model_len = None
+        if max_model_len is not None:
+            derived_max_model_len = self.get_max_model_len()
+            if max_model_len > derived_max_model_len:
+                logger.warning(
+                    f"User-specified max_model_len ({max_model_len}) is "
+                    f"greater than the derived max_model_len "
+                    f"({derived_max_model_len}). Make sure the value is "
+                    "correct and within the model context size.")
+            self.max_model_len = max_model_len

     def _verify_load_format(self) -> None:
         load_format = self.load_format.lower()
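The hunk above makes a user-supplied limit win while warning when it exceeds the length derived from the model's config. As a standalone illustration (not vLLM's API; the helper name and scaffolding below are made up for this sketch), the precedence rule condenses to the following. Note that in ModelConfig itself the attribute stays None when no override is given and the derivation happens lazily in get_max_model_len().

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def resolve_max_model_len(user_value, derived_value):
    # Return the user override when given; warn if it exceeds the derived limit.
    if user_value is None:
        return derived_value
    if user_value > derived_value:
        logger.warning(
            f"User-specified max_model_len ({user_value}) is greater than "
            f"the derived max_model_len ({derived_value}). Make sure the "
            "value is correct and within the model context size.")
    return user_value


print(resolve_max_model_len(4096, 2048))  # warns, then returns 4096
print(resolve_max_model_len(None, 2048))  # falls back to the derived value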
@@ -134,6 +147,8 @@ class ModelConfig:
         return total_num_attention_heads // parallel_config.tensor_parallel_size

     def get_max_model_len(self) -> int:
+        if self.max_model_len is not None:
+            return self.max_model_len
         max_model_len = float("inf")
         possible_keys = [
             # OPT
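The early return added here short-circuits the derivation that follows it: get_max_model_len otherwise scans the Hugging Face config for known context-length attributes and keeps the smallest one found. A rough sketch of that fallback pattern is below; the key names are assumptions for illustration, since the full list sits outside this hunk.

def derive_max_model_len(hf_config) -> float:
    # Scan the HF config for known context-length attributes and take the
    # smallest one found; float("inf") means nothing was found.
    max_model_len = float("inf")
    possible_keys = [
        "max_position_embeddings",  # OPT-style configs (assumed key name)
        "n_positions",              # GPT-2-style configs (assumed key name)
        "seq_length",               # other architectures (assumed key name)
    ]
    for key in possible_keys:
        value = getattr(hf_config, key, None)
        if value is not None:
            max_model_len = min(max_model_len, value)
    return max_model_len


class _FakeConfig:
    max_position_embeddings = 2048


print(derive_max_model_len(_FakeConfig()))  # 2048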
@@ -18,6 +18,7 @@ class EngineArgs:
     load_format: str = 'auto'
     dtype: str = 'auto'
     seed: int = 0
+    max_model_len: Optional[int] = None
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
@@ -89,6 +90,11 @@ class EngineArgs:
                             'The "auto" option will use FP16 precision '
                             'for FP32 and FP16 models, and BF16 precision '
                             'for BF16 models.')
+        parser.add_argument('--max-model-len',
+                            type=int,
+                            default=None,
+                            help='model context length. If unspecified, '
+                            'will be automatically derived from the model.')
         # Parallel arguments
         parser.add_argument('--worker-use-ray',
                             action='store_true',
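With the new dataclass field and CLI flag in place, the value travels through argument parsing like the existing engine knobs. A hedged usage sketch, assuming the add_cli_args / from_cli_args helpers that this file already provides:

import argparse

from vllm.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(
    ["--model", "facebook/opt-125m", "--max-model-len", "1024"])

engine_args = EngineArgs.from_cli_args(args)
print(engine_args.max_model_len)  # 1024; stays None when the flag is omitted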
@@ -153,7 +159,7 @@ class EngineArgs:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
-                                   self.dtype, self.seed)
+                                   self.dtype, self.seed, self.max_model_len)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
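The constructor call above is where the override finally reaches ModelConfig. An end-to-end usage sketch to close the loop, assuming the high-level LLM wrapper forwards extra keyword arguments through to EngineArgs (that forwarding is not shown in this diff):

from vllm import LLM, SamplingParams

# max_model_len flows into EngineArgs and from there into ModelConfig via the
# call in the hunk above; forwarding through the LLM wrapper's **kwargs is an
# assumption about surrounding code.
llm = LLM(model="facebook/opt-125m", max_model_len=512)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)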