[Frontend] Dynamic RoPE scaling (#4638)
commit 9b9a10d6cb
parent 99eff67ba9
tests/test_config.py
@@ -37,3 +37,57 @@ def test_get_sliding_window():
     mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
 
     assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
+
+
+def test_rope_scaling():
+    TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
+    LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
+
+    llama_model_config = ModelConfig(
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+    )
+    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
+    assert llama_model_config.max_model_len == 8192
+
+    llama_model_config = ModelConfig(
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+        rope_scaling=TEST_ROPE_SCALING,
+    )
+    assert getattr(llama_model_config.hf_config, "rope_scaling",
+                   None) == TEST_ROPE_SCALING
+    assert llama_model_config.max_model_len == 16384
+
+    longchat_model_config = ModelConfig(
+        "lmsys/longchat-13b-16k",
+        "lmsys/longchat-13b-16k",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+    )
+    assert getattr(longchat_model_config.hf_config, "rope_scaling",
+                   None) == LONGCHAT_ROPE_SCALING
+    assert longchat_model_config.max_model_len == 16384
+
+    longchat_model_config = ModelConfig(
+        "lmsys/longchat-13b-16k",
+        "lmsys/longchat-13b-16k",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+        rope_scaling=TEST_ROPE_SCALING,
+    )
+    assert getattr(longchat_model_config.hf_config, "rope_scaling",
+                   None) == TEST_ROPE_SCALING
+    assert longchat_model_config.max_model_len == 4096
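The expected max_model_len values in this test follow from multiplying each model's base context window by the scaling factor. A minimal sketch of that arithmetic, assuming vLLM scales the derived limit by rope_scaling["factor"] and that Meta-Llama-3-8B-Instruct and longchat-13b-16k ship base windows of 8192 and 2048 tokens respectively (the 2048 is inferred from the asserts, not stated in the diff):

# Sketch only: factor-times-base rule assumed from the asserts above.
def expected_max_len(base_window: int, factor: float = 1.0) -> int:
    return int(base_window * factor)

assert expected_max_len(8192) == 8192          # Llama-3, no scaling
assert expected_max_len(8192, 2.0) == 16384    # Llama-3 + dynamic factor 2.0
assert expected_max_len(2048, 8.0) == 16384    # longchat's built-in linear 8.0
assert expected_max_len(2048, 2.0) == 4096     # longchat overridden to dynamic 2.0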
vllm/config.py
@@ -45,6 +45,9 @@ class ModelConfig:
         code_revision: The specific revision to use for the model code on
             Hugging Face Hub. It can be a branch name, a tag name, or a
             commit id. If unspecified, will use the default version.
+        rope_scaling: Dictionary containing the scaling configuration for the
+            RoPE embeddings. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
             the default version.
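To make the docstring's caveat concrete, a hypothetical example (values chosen to match the Llama-3 case in the test above):

# Hypothetical example: extend an 8192-token window to ~16K.
rope_scaling = {"type": "dynamic", "factor": 2.0}
# Keep the checkpoint's max_position_embeddings at 8192; the scaled limit
# (8192 * 2.0 = 16384) comes from the factor, so raising the field as well
# would double-count the scaling.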
@@ -84,6 +87,7 @@ class ModelConfig:
         seed: int,
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
+        rope_scaling: Optional[dict] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
@@ -102,6 +106,7 @@ class ModelConfig:
         self.seed = seed
         self.revision = revision
         self.code_revision = code_revision
+        self.rope_scaling = rope_scaling
         self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.quantization_param_path = quantization_param_path
@@ -116,7 +121,7 @@ class ModelConfig:
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision)
+                                    code_revision, rope_scaling)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
vllm/engine/arg_utils.py
@@ -1,5 +1,6 @@
 import argparse
 import dataclasses
+import json
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
@@ -49,6 +50,7 @@ class EngineArgs:
     disable_log_stats: bool = False
     revision: Optional[str] = None
     code_revision: Optional[str] = None
+    rope_scaling: Optional[dict] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: bool = False
@@ -330,6 +332,11 @@ class EngineArgs:
                             'None, we assume the model weights are not '
                             'quantized and use `dtype` to determine the data '
                             'type of the weights.')
+        parser.add_argument('--rope-scaling',
+                            default=None,
+                            type=json.loads,
+                            help='RoPE scaling configuration in JSON format. '
+                            'For example, {"type":"dynamic","factor":2.0}')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '
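A standalone sketch of how type=json.loads turns the flag's string value into a dict (plain argparse, not vLLM code; the parsed value then lands on EngineArgs.rope_scaling):

import argparse
import json

# Minimal reproduction of the parsing behaviour of the new flag.
parser = argparse.ArgumentParser()
parser.add_argument('--rope-scaling', default=None, type=json.loads)

args = parser.parse_args(['--rope-scaling', '{"type":"dynamic","factor":2.0}'])
assert args.rope_scaling == {"type": "dynamic", "factor": 2.0}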
@@ -548,11 +555,12 @@ class EngineArgs:
         model_config = ModelConfig(
             self.model, self.tokenizer, self.tokenizer_mode,
             self.trust_remote_code, self.dtype, self.seed, self.revision,
-            self.code_revision, self.tokenizer_revision, self.max_model_len,
-            self.quantization, self.quantization_param_path,
-            self.enforce_eager, self.max_context_len_to_capture,
-            self.max_seq_len_to_capture, self.max_logprobs,
-            self.skip_tokenizer_init, self.served_model_name)
+            self.code_revision, self.rope_scaling, self.tokenizer_revision,
+            self.max_model_len, self.quantization,
+            self.quantization_param_path, self.enforce_eager,
+            self.max_context_len_to_capture, self.max_seq_len_to_capture,
+            self.max_logprobs, self.skip_tokenizer_init,
+            self.served_model_name)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
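Because EngineArgs now carries rope_scaling and threads it into ModelConfig here, the offline entrypoint can use the override too. A hedged sketch, assuming vllm.LLM forwards extra keyword arguments to EngineArgs:

from vllm import LLM

# Sketch under the assumption above: rope_scaling flows LLM -> EngineArgs ->
# ModelConfig -> get_config, doubling the usable context via dynamic scaling.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    rope_scaling={"type": "dynamic", "factor": 2.0},
    dtype="float16",
)
outputs = llm.generate("The long document to summarize ...")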
vllm/engine/llm_engine.py
@@ -104,10 +104,11 @@ class LLMEngine:
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
-            "max_seq_len=%d, download_dir=%r, load_format=%s, "
-            "tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
-            "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
+            "rope_scaling=%r, tokenizer_revision=%s, "
+            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
+            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
+            "disable_custom_all_reduce=%s, quantization=%s, "
+            "enforce_eager=%s, kv_cache_dtype=%s, "
             "quantization_param_path=%s, device_config=%s, "
             "decoding_config=%r, seed=%d, served_model_name=%s)",
             vllm.__version__,
@@ -117,6 +118,7 @@ class LLMEngine:
             model_config.skip_tokenizer_init,
             model_config.tokenizer_mode,
             model_config.revision,
+            model_config.rope_scaling,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,
vllm/transformers_utils/config.py
@@ -2,9 +2,12 @@ from typing import Dict, Optional
 
 from transformers import AutoConfig, PretrainedConfig
 
+from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              JAISConfig, MPTConfig, RWConfig)
 
+logger = init_logger(__name__)
+
 _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
@@ -18,7 +21,8 @@ _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
 def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None,
-               code_revision: Optional[str] = None) -> PretrainedConfig:
+               code_revision: Optional[str] = None,
+               rope_scaling: Optional[dict] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
             model,
@@ -41,6 +45,10 @@ def get_config(model: str,
         config = config_class.from_pretrained(model,
                                               revision=revision,
                                               code_revision=code_revision)
+    if rope_scaling is not None:
+        logger.info("Updating rope_scaling from %r to %r",
+                    getattr(config, "rope_scaling", None), rope_scaling)
+        config.update({"rope_scaling": rope_scaling})
     return config
 
 
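The config.update(...) call above relies on transformers' PretrainedConfig.update, which copies each dict entry onto the config as an attribute. A small sketch of that override path outside vLLM (the printed values assume the Hub config matches the expectations in the test at the top of this diff):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("lmsys/longchat-13b-16k")
print(getattr(config, "rope_scaling", None))  # expected: {'type': 'linear', 'factor': 8.0}

# Same override mechanism as get_config above.
config.update({"rope_scaling": {"type": "dynamic", "factor": 2.0}})
print(config.rope_scaling)  # {'type': 'dynamic', 'factor': 2.0}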