[Misc] Standardize RoPE handling for Qwen2-VL (#9250)

Cyrus Leung 2024-10-16 13:56:17 +08:00 committed by GitHub
parent ed920135c8
commit 7e7eae338d
16 changed files with 102 additions and 200 deletions

View File

@@ -31,7 +31,7 @@ def benchmark_rope_kernels_multi_lora(
     # batched RoPE can take multiple scaling factors
     batched_rope = get_rope(head_size, rotary_dim, max_position, base,
                             is_neox_style, {
-                                "type": "linear",
+                                "rope_type": "linear",
                                 "factor": tuple(scaling_factors)
                             })
     # non-batched RoPE takes only one scaling factor, we create multiple
@@ -41,7 +41,7 @@ def benchmark_rope_kernels_multi_lora(
         non_batched_ropes.append(
             get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                      {
-                         "type": "linear",
+                         "rope_type": "linear",
                          "factor": (scaling_factor, )
                      }))
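Both hunks show the core of this change: the scaling dict passed to get_rope now carries the standardized "rope_type" key instead of the legacy "type" key. A minimal sketch of the new calling convention (head size, base, and factor values are illustrative, not taken from the benchmark):

    from vllm.model_executor.layers.rotary_embedding import get_rope

    # Linear RoPE scaling with the standardized key; vLLM code no longer
    # builds {"type": ...} dicts itself.
    rope = get_rope(128, 128, 4096, 10000, True, {
        "rope_type": "linear",
        "factor": (1.0, ),
    })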

View File

@@ -4,7 +4,7 @@ numpy < 2.0.0
 requests >= 2.26.0
 tqdm
 py-cpuinfo
-transformers >= 4.45.0  # Required for Llama 3.2.
+transformers >= 4.45.2  # Required for Llama 3.2 and Qwen2-VL.
 tokenizers >= 0.19.1  # Required for Llama 3.
 protobuf # Required by LlamaTokenizer.
 fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'

View File

@@ -105,7 +105,7 @@ def test_batched_rotary_embedding(
     if rotary_dim is None:
         rotary_dim = head_size
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
-        "type": "linear",
+        "rope_type": "linear",
         "factor": (1, )
     })
     rope = rope.to(dtype=dtype)
@@ -166,7 +166,7 @@ def test_batched_rotary_embedding_multi_lora(
         rotary_dim = head_size
     scaling_factors: List[int] = [1, 2, 4]
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
-        "type": "linear",
+        "rope_type": "linear",
         "factor": tuple(scaling_factors)
     })
     rope = rope.to(dtype=dtype)
@@ -211,10 +211,10 @@ def test_rope_module_cache():
     MAX_POSITIONS = [123, 1234]
     BASES = [10000, 1000000]
     ROPE_SCALINGS = (None, {
-        "type": "linear",
+        "rope_type": "linear",
         "factor": (1, )
     }, {
-        "type": "dynamic",
+        "rope_type": "dynamic",
         "factor": 1
     })
     settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,

View File

@@ -951,7 +951,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
     lora_rope.create_lora_weights(max_loras, lora_config)
     linear_rope = get_rope(head_size, rotary_dim, max_position, base,
                            is_neox_style, {
-                               "type": "linear",
+                               "rope_type": "linear",
                                "factor": scaling_factors
                            })
     linear_rope = linear_rope.to(dtype=dtype)

View File

@@ -64,9 +64,9 @@ def test_get_sliding_window():
 def test_rope_customization():
-    TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
+    TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
     TEST_ROPE_THETA = 16_000_000.0
-    LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
+    LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}

     llama_model_config = ModelConfig(
         "meta-llama/Meta-Llama-3-8B-Instruct",

View File

@@ -1739,16 +1739,10 @@ def _get_and_verify_max_len(
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
-        if "type" in rope_scaling:
-            rope_type = rope_scaling["type"]
-        elif "rope_type" in rope_scaling:
-            rope_type = rope_scaling["rope_type"]
-        else:
-            raise ValueError(
-                "rope_scaling must have a 'type' or 'rope_type' key.")
-
-        # The correct one should be "longrope", kept "su" here
-        # to be backward compatible
+        # No need to consider "type" key because of patch_rope_scaling when
+        # loading HF config
+        rope_type = rope_scaling["rope_type"]
+
         if rope_type not in ("su", "longrope", "llama3"):
             if disable_sliding_window:
                 # TODO(robertgshaw): Find a model that supports rope_scaling
@@ -1758,11 +1752,10 @@ def _get_and_verify_max_len(
                     "with rope_scaling. Please raise an issue so we can "
                     "investigate.")
-            if rope_type == "mrope":
-                scaling_factor = 1
-            else:
-                assert "factor" in rope_scaling
-                scaling_factor = rope_scaling["factor"]
+            # NOTE: rope_type == "default" does not define factor
+            # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
+            scaling_factor = rope_scaling.get("factor", 1.0)
+
             if rope_type == "yarn":
                 derived_max_model_len = rope_scaling[
                     "original_max_position_embeddings"]

View File

@@ -454,11 +454,12 @@ class EngineArgs:
                             'None, we assume the model weights are not '
                             'quantized and use `dtype` to determine the data '
                             'type of the weights.')
-        parser.add_argument('--rope-scaling',
-                            default=None,
-                            type=json.loads,
-                            help='RoPE scaling configuration in JSON format. '
-                            'For example, {"type":"dynamic","factor":2.0}')
+        parser.add_argument(
+            '--rope-scaling',
+            default=None,
+            type=json.loads,
+            help='RoPE scaling configuration in JSON format. '
+            'For example, {"rope_type":"dynamic","factor":2.0}')
         parser.add_argument('--rope-theta',
                             default=None,
                             type=float,
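The help text now advertises the "rope_type" key. A minimal usage sketch through the offline LLM entrypoint, which forwards these engine arguments to EngineArgs; the model name and values are only examples:

    from vllm import LLM

    llm = LLM(
        model="meta-llama/Meta-Llama-3-8B-Instruct",  # example model
        rope_scaling={"rope_type": "dynamic", "factor": 2.0},
        rope_theta=16_000_000.0,
    )

The equivalent server flag is --rope-scaling '{"rope_type":"dynamic","factor":2.0}'.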

View File

@@ -920,13 +920,10 @@ def get_rope(
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
     else:
-        scaling_type = rope_scaling[
-            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
-        # The correct one should be "longrope" but keep "su" here
-        # for backward compatible
-        if scaling_type not in {"su", "longrope"}:
-            scaling_factor = rope_scaling.get("factor", 1.0)
+        scaling_type = rope_scaling["rope_type"]
+
         if scaling_type == "llama3":
+            scaling_factor = rope_scaling["factor"]
             low_freq_factor = rope_scaling["low_freq_factor"]
             high_freq_factor = rope_scaling["high_freq_factor"]
             original_max_position = rope_scaling[
@@ -937,16 +934,39 @@ def get_rope(
                                                scaling_factor, low_freq_factor,
                                                high_freq_factor,
                                                original_max_position)
+        elif scaling_type == "default":
+            if "mrope_section" in rope_scaling:
+                rotary_emb = MRotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    max_position,
+                    base,
+                    is_neox_style,
+                    dtype,
+                    mrope_section=rope_scaling["mrope_section"],
+                )
+            else:
+                rotary_emb = RotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    max_position,
+                    base,
+                    is_neox_style,
+                    dtype,
+                )
         elif scaling_type == "linear":
+            scaling_factor = rope_scaling["factor"]
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
                                                       scaling_factor, dtype)
         elif scaling_type == "dynamic":
+            scaling_factor = rope_scaling["factor"]
             rotary_emb = DynamicNTKScalingRotaryEmbedding(
                 head_size, rotary_dim, max_position, base, is_neox_style,
                 scaling_factor, dtype)
         elif scaling_type == "yarn":
+            scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
             extra_kwargs = {
@@ -961,6 +981,7 @@ def get_rope(
                 scaling_factor, dtype,
                 **extra_kwargs)
         elif scaling_type == "deepseek_yarn":
+            scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
             # assert max_position == original_max_position * scaling_factor
@@ -973,9 +994,7 @@ def get_rope(
             rotary_emb = DeepseekScalingRotaryEmbedding(
                 head_size, rotary_dim, original_max_position, base,
                 is_neox_style, scaling_factor, dtype, **extra_kwargs)
-        # The correct one should be "longrope" but keep "su" here
-        # for backward compatible
-        elif scaling_type == "su" or scaling_type == "longrope":
+        elif scaling_type == "longrope":
             short_factor = rope_scaling["short_factor"]
             long_factor = rope_scaling["long_factor"]
             original_max_position = rope_scaling[
@@ -989,16 +1008,6 @@ def get_rope(
                 head_size, rotary_dim, max_position, original_max_position,
                 base, is_neox_style, dtype, short_factor, long_factor,
                 **extra_kwargs)
-        elif scaling_type == "mrope":
-            rotary_emb = MRotaryEmbedding(
-                head_size,
-                rotary_dim,
-                max_position,
-                base,
-                is_neox_style,
-                dtype,
-                mrope_section=rope_scaling["mrope_section"],
-            )
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
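After this refactor there is no standalone "mrope" scaling type: a config normalized by patch_rope_scaling reaches get_rope with rope_type "default" plus an "mrope_section", and the "default" branch selects MRotaryEmbedding. A sketch of that dispatch (sizes are illustrative; the mrope_section mirrors Qwen2-VL's published config):

    from vllm.model_executor.layers.rotary_embedding import (MRotaryEmbedding,
                                                              get_rope)

    rope = get_rope(
        128,      # head_size
        128,      # rotary_dim
        32768,    # max_position
        1000000,  # base
        True,     # is_neox_style
        {"rope_type": "default", "mrope_section": [16, 24, 24]},
    )
    assert isinstance(rope, MRotaryEmbedding)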

View File

@@ -242,7 +242,7 @@ class DeepseekV2Attention(nn.Module):
                                         bias=False,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")
-        rope_scaling['type'] = 'deepseek_yarn'
+        rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,

View File

@@ -179,7 +179,7 @@ class Phi3SmallSelfAttention(nn.Module):
             rope_scaling["factor"] = self.rope_position_scale
         else:
             rope_scaling = {
-                "type": "linear",
+                "rope_type": "linear",
                 "factor": self.rope_position_scale,
             }

View File

@@ -34,6 +34,8 @@ from PIL import Image
 from transformers.image_utils import (get_image_size,
                                       infer_channel_dimension_format,
                                       to_numpy_array)
+from transformers.models.qwen2_vl.configuration_qwen2_vl import (
+    Qwen2VLConfig, Qwen2VLVisionConfig)
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
     make_batched_images, make_batched_videos, smart_resize)
@@ -62,8 +64,7 @@ from vllm.multimodal.base import MultiModalData
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceData
-from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
-                                                     Qwen2VLVisionConfig)
+from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import get_processor
 from vllm.utils import is_cpu
@@ -1061,8 +1062,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         if image_input is None and video_input is None:
             inputs_embeds = None
         else:
-            rope_scaling = getattr(self.config, "rope_scaling", {})
-            if rope_scaling.get("type", None) == "mrope":
+            if uses_mrope(self.config):
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}")

View File

@@ -23,8 +23,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              MedusaConfig, MllamaConfig,
                                              MLPSpeculatorConfig, MPTConfig,
                                              NemotronConfig, NVLM_D_Config,
-                                             Qwen2VLConfig, RWConfig,
-                                             SolarConfig, UltravoxConfig)
+                                             RWConfig, SolarConfig,
+                                             UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.utils import check_gguf_file
@@ -57,7 +57,6 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "NVLM_D": NVLM_D_Config,
     "solar": SolarConfig,
     "ultravox": UltravoxConfig,
-    "qwen2_vl": Qwen2VLConfig,
     **_CONFIG_REGISTRY_OVERRIDE_HF
 }
@@ -91,6 +90,43 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
     return False
+
+
+def patch_rope_scaling(config: PretrainedConfig) -> None:
+    """Provide backwards compatibility for RoPE."""
+    text_config = getattr(config, "text_config", None)
+    if text_config is not None:
+        patch_rope_scaling(text_config)
+
+    rope_scaling = getattr(config, "rope_scaling", None)
+    if rope_scaling is not None:
+        patch_rope_scaling_dict(rope_scaling)
+
+
+def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
+    if "rope_type" not in rope_scaling and "type" in rope_scaling:
+        rope_scaling["rope_type"] = rope_scaling["type"]
+        logger.info("Replacing legacy 'type' key with 'rope_type'")
+
+    if "rope_type" not in rope_scaling:
+        raise ValueError("rope_scaling should have a 'rope_type' key")
+
+    if rope_scaling["rope_type"] == "su":
+        rope_scaling["rope_type"] = "longrope"
+        logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
+    elif rope_scaling["rope_type"] == "mrope":
+        assert "mrope_section" in rope_scaling
+        rope_scaling["rope_type"] = "default"
+        logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
+
+
+def uses_mrope(config: PretrainedConfig) -> bool:
+    """Detect if the model with this config uses M-ROPE."""
+    rope_scaling = getattr(config, "rope_scaling", None)
+    if rope_scaling is None:
+        return False
+
+    return "mrope_section" in rope_scaling
+
+
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
@@ -191,6 +227,8 @@ def get_config(
             )
             config.update({key: value})

+    patch_rope_scaling(config)
+
     return config
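Taken together, these helpers let the rest of vLLM assume a normalized rope_scaling dict. A small sketch of the behavior on a legacy Qwen2-VL style entry (SimpleNamespace merely stands in for a real PretrainedConfig; the section values mirror Qwen2-VL's config):

    from types import SimpleNamespace

    from vllm.transformers_utils.config import (patch_rope_scaling_dict,
                                                 uses_mrope)

    rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}
    patch_rope_scaling_dict(rope_scaling)
    # rope_scaling now also carries "rope_type": "default"; the legacy
    # "type" key is left in place.

    config = SimpleNamespace(rope_scaling=rope_scaling)
    assert uses_mrope(config)  # detection keys off "mrope_section"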

View File

@@ -14,8 +14,6 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
 from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
-from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
-                                                     Qwen2VLVisionConfig)
 from vllm.transformers_utils.configs.solar import SolarConfig
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
@@ -35,6 +33,4 @@ __all__ = [
     "NVLM_D_Config",
     "SolarConfig",
     "UltravoxConfig",
-    "Qwen2VLConfig",
-    "Qwen2VLVisionConfig",
 ]

View File

@@ -1,131 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Qwen2VL model configuration"""
-
-import os
-from typing import Union
-
-from transformers import PretrainedConfig
-
-
-class Qwen2VLVisionConfig(PretrainedConfig):
-    model_type = "qwen2_vl"
-
-    def __init__(
-        self,
-        depth=32,
-        embed_dim=1280,
-        hidden_size=3584,
-        hidden_act="quick_gelu",
-        mlp_ratio=4,
-        num_heads=16,
-        in_channels=3,
-        patch_size=14,
-        spatial_merge_size=2,
-        temporal_patch_size=2,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.depth = depth
-        self.embed_dim = embed_dim
-        self.hidden_size = hidden_size
-        self.hidden_act = hidden_act
-        self.mlp_ratio = mlp_ratio
-        self.num_heads = num_heads
-        self.in_channels = in_channels
-        self.patch_size = patch_size
-        self.spatial_merge_size = spatial_merge_size
-        self.temporal_patch_size = temporal_patch_size
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
-                                                                   os.PathLike],
-                        **kwargs) -> "PretrainedConfig":
-        cls._set_token_in_kwargs(kwargs)
-
-        config_dict, kwargs = cls.get_config_dict(
-            pretrained_model_name_or_path, **kwargs)
-
-        if config_dict.get("model_type") == "qwen2_vl":
-            config_dict = config_dict["vision_config"]
-
-        return cls.from_dict(config_dict, **kwargs)
-
-
-class Qwen2VLConfig(PretrainedConfig):
-
-    def __init__(
-        self,
-        vocab_size=152064,
-        hidden_size=8192,
-        intermediate_size=29568,
-        num_hidden_layers=80,
-        num_attention_heads=64,
-        num_key_value_heads=8,
-        hidden_act="silu",
-        max_position_embeddings=32768,
-        initializer_range=0.02,
-        rms_norm_eps=1e-05,
-        use_cache=True,
-        tie_word_embeddings=False,
-        rope_theta=1000000.0,
-        use_sliding_window=False,
-        sliding_window=4096,
-        max_window_layers=80,
-        attention_dropout=0.0,
-        vision_config=None,
-        rope_scaling=None,
-        **kwargs,
-    ):
-        if isinstance(vision_config, dict):
-            self.vision_config = Qwen2VLVisionConfig(**vision_config)
-        elif vision_config is None:
-            self.vision_config = Qwen2VLVisionConfig()
-
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.use_sliding_window = use_sliding_window
-        self.sliding_window = sliding_window
-        self.max_window_layers = max_window_layers
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.attention_dropout = attention_dropout
-        self.rope_scaling = rope_scaling
-
-        # NOTE: the following section from original transformers config
-        # for Qwen2-VL is commented out to address rope config loading issue
-        #
-        # if self.rope_scaling is not None and "type" in self.rope_scaling:
-        #     if self.rope_scaling["type"] == "mrope":
-        #         self.rope_scaling["type"] = "default"
-        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        # rope_config_validation(self)
-
-        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
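This vendored config is deleted because transformers >= 4.45.2 (see requirements-common.txt) ships the same classes; the model code now imports them from upstream. A minimal sketch of the replacement import:

    from transformers.models.qwen2_vl.configuration_qwen2_vl import (
        Qwen2VLConfig, Qwen2VLVisionConfig)

    config = Qwen2VLConfig()  # defaults match the vendored copy above
    assert isinstance(config.vision_config, Qwen2VLVisionConfig)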

View File

@@ -19,6 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
+from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import make_tensor_with_pad
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@@ -439,10 +440,7 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
         mrope requires keep "rope_deltas" between prompt and decoding phases."""
-        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
-        if rope_scaling is None:
-            return False
-        return rope_scaling.get("type", None) == "mrope"
+        return uses_mrope(self.model_config.hf_config)

     def load_model(self) -> None:
         self.model = get_model(model_config=self.model_config,

View File

@@ -47,6 +47,7 @@ from vllm.prompt_adapter.worker_manager import (
     LRUCacheWorkerPromptAdapterManager)
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
                         flatten_2d_lists, is_hip, is_pin_memory_available,
                         supports_dynamo)
@@ -1379,10 +1380,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
         mrope requires keep "rope_deltas" between prompt and decoding phases."""
-        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
-        if rope_scaling is None:
-            return False
-        return rope_scaling.get("type", None) == "mrope"
+        return uses_mrope(self.model_config.hf_config)

     @torch.inference_mode()
     def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: