[Frontend] Dynamic RoPE scaling (#4638)

sasha0552 2024-05-22 05:32:35 +00:00 committed by GitHub
parent 99eff67ba9
commit 9b9a10d6cb
5 changed files with 89 additions and 12 deletions

tests/test_config.py

@@ -37,3 +37,57 @@ def test_get_sliding_window():
     mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
     assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
+
+
+def test_rope_scaling():
+    TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
+    LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
+
+    llama_model_config = ModelConfig(
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+    )
+    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
+    assert llama_model_config.max_model_len == 8192
+
+    llama_model_config = ModelConfig(
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+        rope_scaling=TEST_ROPE_SCALING,
+    )
+    assert getattr(llama_model_config.hf_config, "rope_scaling",
+                   None) == TEST_ROPE_SCALING
+    assert llama_model_config.max_model_len == 16384
+
+    longchat_model_config = ModelConfig(
+        "lmsys/longchat-13b-16k",
+        "lmsys/longchat-13b-16k",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+    )
+    assert getattr(longchat_model_config.hf_config, "rope_scaling",
+                   None) == LONGCHAT_ROPE_SCALING
+    assert longchat_model_config.max_model_len == 16384
+
+    longchat_model_config = ModelConfig(
+        "lmsys/longchat-13b-16k",
+        "lmsys/longchat-13b-16k",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+        rope_scaling=TEST_ROPE_SCALING,
+    )
+    assert getattr(longchat_model_config.hf_config, "rope_scaling",
+                   None) == TEST_ROPE_SCALING
+    assert longchat_model_config.max_model_len == 4096
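
For reference (not part of the commit): the expected max_model_len values above follow from multiplying the checkpoint's base max_position_embeddings by the RoPE scaling factor, which is how vLLM derives the context length for linear and dynamic scaling. A minimal sketch of that arithmetic, assuming Llama-3-8B-Instruct ships with max_position_embeddings=8192 and longchat-13b-16k with a base of 2048:

from typing import Optional

def expected_max_len(max_position_embeddings: int,
                     rope_scaling: Optional[dict]) -> int:
    # Derived context length: base length scaled by the RoPE factor.
    if rope_scaling is None:
        return max_position_embeddings
    return int(max_position_embeddings * rope_scaling["factor"])

assert expected_max_len(8192, None) == 8192                                  # Llama-3, no scaling
assert expected_max_len(8192, {"type": "dynamic", "factor": 2.0}) == 16384   # Llama-3, override
assert expected_max_len(2048, {"type": "linear", "factor": 8.0}) == 16384    # longchat default
assert expected_max_len(2048, {"type": "dynamic", "factor": 2.0}) == 4096    # longchat, override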

vllm/config.py

@@ -45,6 +45,9 @@ class ModelConfig:
         code_revision: The specific revision to use for the model code on
             Hugging Face Hub. It can be a branch name, a tag name, or a
             commit id. If unspecified, will use the default version.
+        rope_scaling: Dictionary containing the scaling configuration for the
+            RoPE embeddings. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
             the default version.
@@ -84,6 +87,7 @@ class ModelConfig:
         seed: int,
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
+        rope_scaling: Optional[dict] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
@@ -102,6 +106,7 @@ class ModelConfig:
         self.seed = seed
         self.revision = revision
         self.code_revision = code_revision
+        self.rope_scaling = rope_scaling
         self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.quantization_param_path = quantization_param_path
@@ -116,7 +121,7 @@ class ModelConfig:
         self.skip_tokenizer_init = skip_tokenizer_init

         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision)
+                                    code_revision, rope_scaling)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
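
Because EngineArgs (below) forwards the new field straight into ModelConfig, the override is also reachable from the offline entrypoint. A hedged usage sketch, not part of this diff, assuming the LLM class passes extra keyword arguments through to EngineArgs:

from vllm import LLM

# Illustrative only: doubles Llama-3-8B-Instruct's usable context
# (8192 -> 16384) via dynamic NTK scaling at load time.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
          rope_scaling={"type": "dynamic", "factor": 2.0})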

vllm/engine/arg_utils.py

@@ -1,5 +1,6 @@
 import argparse
 import dataclasses
+import json
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
@@ -49,6 +50,7 @@ class EngineArgs:
     disable_log_stats: bool = False
     revision: Optional[str] = None
     code_revision: Optional[str] = None
+    rope_scaling: Optional[dict] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: bool = False
@@ -330,6 +332,11 @@ class EngineArgs:
                             'None, we assume the model weights are not '
                             'quantized and use `dtype` to determine the data '
                             'type of the weights.')
+        parser.add_argument('--rope-scaling',
+                            default=None,
+                            type=json.loads,
+                            help='RoPE scaling configuration in JSON format. '
+                            'For example, {"type":"dynamic","factor":2.0}')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '
@@ -548,11 +555,12 @@ class EngineArgs:
         model_config = ModelConfig(
             self.model, self.tokenizer, self.tokenizer_mode,
             self.trust_remote_code, self.dtype, self.seed, self.revision,
-            self.code_revision, self.tokenizer_revision, self.max_model_len,
-            self.quantization, self.quantization_param_path,
-            self.enforce_eager, self.max_context_len_to_capture,
-            self.max_seq_len_to_capture, self.max_logprobs,
-            self.skip_tokenizer_init, self.served_model_name)
+            self.code_revision, self.rope_scaling, self.tokenizer_revision,
+            self.max_model_len, self.quantization,
+            self.quantization_param_path, self.enforce_eager,
+            self.max_context_len_to_capture, self.max_seq_len_to_capture,
+            self.max_logprobs, self.skip_tokenizer_init,
+            self.served_model_name)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
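
For illustration (not part of the diff): because the new flag uses type=json.loads, argparse hands the engine a plain dict. A minimal sketch of that parsing step:

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument('--rope-scaling', default=None, type=json.loads)

# On a real command line the JSON must be quoted as one shell argument,
# e.g. --rope-scaling '{"type":"dynamic","factor":2.0}'.
args = parser.parse_args(['--rope-scaling', '{"type": "dynamic", "factor": 2.0}'])
assert args.rope_scaling == {"type": "dynamic", "factor": 2.0}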

vllm/engine/llm_engine.py

@@ -104,10 +104,11 @@ class LLMEngine:
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
-            "max_seq_len=%d, download_dir=%r, load_format=%s, "
-            "tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
-            "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
+            "rope_scaling=%r, tokenizer_revision=%s, "
+            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
+            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
+            "disable_custom_all_reduce=%s, quantization=%s, "
+            "enforce_eager=%s, kv_cache_dtype=%s, "
             "quantization_param_path=%s, device_config=%s, "
             "decoding_config=%r, seed=%d, served_model_name=%s)",
             vllm.__version__,
@@ -117,6 +118,7 @@ class LLMEngine:
             model_config.skip_tokenizer_init,
             model_config.tokenizer_mode,
             model_config.revision,
+            model_config.rope_scaling,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,

vllm/transformers_utils/config.py

@@ -2,9 +2,12 @@ from typing import Dict, Optional

 from transformers import AutoConfig, PretrainedConfig

+from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              JAISConfig, MPTConfig, RWConfig)

+logger = init_logger(__name__)
+
 _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
@@ -18,7 +21,8 @@ _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
 def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None,
-               code_revision: Optional[str] = None) -> PretrainedConfig:
+               code_revision: Optional[str] = None,
+               rope_scaling: Optional[dict] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
             model,
@@ -41,6 +45,10 @@ def get_config(model: str,
         config = config_class.from_pretrained(model,
                                               revision=revision,
                                               code_revision=code_revision)
+    if rope_scaling is not None:
+        logger.info("Updating rope_scaling from %r to %r",
+                    getattr(config, "rope_scaling", None), rope_scaling)
+        config.update({"rope_scaling": rope_scaling})
     return config
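
For reference (not from the commit): PretrainedConfig.update simply overwrites attributes from the given dict, so the user-supplied value replaces whatever rope_scaling the checkpoint declares. A small self-contained sketch of that behaviour:

from transformers import PretrainedConfig

# Stand-in for a checkpoint config that already declares linear scaling.
config = PretrainedConfig(rope_scaling={"type": "linear", "factor": 8.0})

# Same call get_config() makes when a rope_scaling override is supplied.
config.update({"rope_scaling": {"type": "dynamic", "factor": 2.0}})
assert config.rope_scaling == {"type": "dynamic", "factor": 2.0}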