[Misc] Standardize RoPE handling for Qwen2-VL (#9250)
This commit is contained in:
parent
ed920135c8
commit
7e7eae338d
@ -31,7 +31,7 @@ def benchmark_rope_kernels_multi_lora(
|
|||||||
# batched RoPE can take multiple scaling factors
|
# batched RoPE can take multiple scaling factors
|
||||||
batched_rope = get_rope(head_size, rotary_dim, max_position, base,
|
batched_rope = get_rope(head_size, rotary_dim, max_position, base,
|
||||||
is_neox_style, {
|
is_neox_style, {
|
||||||
"type": "linear",
|
"rope_type": "linear",
|
||||||
"factor": tuple(scaling_factors)
|
"factor": tuple(scaling_factors)
|
||||||
})
|
})
|
||||||
# non-batched RoPE takes only one scaling factor, we create multiple
|
# non-batched RoPE takes only one scaling factor, we create multiple
|
||||||
@ -41,7 +41,7 @@ def benchmark_rope_kernels_multi_lora(
|
|||||||
non_batched_ropes.append(
|
non_batched_ropes.append(
|
||||||
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
|
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
|
||||||
{
|
{
|
||||||
"type": "linear",
|
"rope_type": "linear",
|
||||||
"factor": (scaling_factor, )
|
"factor": (scaling_factor, )
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ numpy < 2.0.0
|
|||||||
requests >= 2.26.0
|
requests >= 2.26.0
|
||||||
tqdm
|
tqdm
|
||||||
py-cpuinfo
|
py-cpuinfo
|
||||||
transformers >= 4.45.0 # Required for Llama 3.2.
|
transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
|
||||||
tokenizers >= 0.19.1 # Required for Llama 3.
|
tokenizers >= 0.19.1 # Required for Llama 3.
|
||||||
protobuf # Required by LlamaTokenizer.
|
protobuf # Required by LlamaTokenizer.
|
||||||
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
|
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
|
||||||
|
@ -105,7 +105,7 @@ def test_batched_rotary_embedding(
|
|||||||
if rotary_dim is None:
|
if rotary_dim is None:
|
||||||
rotary_dim = head_size
|
rotary_dim = head_size
|
||||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
||||||
"type": "linear",
|
"rope_type": "linear",
|
||||||
"factor": (1, )
|
"factor": (1, )
|
||||||
})
|
})
|
||||||
rope = rope.to(dtype=dtype)
|
rope = rope.to(dtype=dtype)
|
||||||
@ -166,7 +166,7 @@ def test_batched_rotary_embedding_multi_lora(
|
|||||||
rotary_dim = head_size
|
rotary_dim = head_size
|
||||||
scaling_factors: List[int] = [1, 2, 4]
|
scaling_factors: List[int] = [1, 2, 4]
|
||||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
||||||
"type": "linear",
|
"rope_type": "linear",
|
||||||
"factor": tuple(scaling_factors)
|
"factor": tuple(scaling_factors)
|
||||||
})
|
})
|
||||||
rope = rope.to(dtype=dtype)
|
rope = rope.to(dtype=dtype)
|
||||||
@ -211,10 +211,10 @@ def test_rope_module_cache():
|
|||||||
MAX_POSITIONS = [123, 1234]
|
MAX_POSITIONS = [123, 1234]
|
||||||
BASES = [10000, 1000000]
|
BASES = [10000, 1000000]
|
||||||
ROPE_SCALINGS = (None, {
|
ROPE_SCALINGS = (None, {
|
||||||
"type": "linear",
|
"rope_type": "linear",
|
||||||
"factor": (1, )
|
"factor": (1, )
|
||||||
}, {
|
}, {
|
||||||
"type": "dynamic",
|
"rope_type": "dynamic",
|
||||||
"factor": 1
|
"factor": 1
|
||||||
})
|
})
|
||||||
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
|
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
|
||||||
|
@ -951,7 +951,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
|||||||
lora_rope.create_lora_weights(max_loras, lora_config)
|
lora_rope.create_lora_weights(max_loras, lora_config)
|
||||||
linear_rope = get_rope(head_size, rotary_dim, max_position, base,
|
linear_rope = get_rope(head_size, rotary_dim, max_position, base,
|
||||||
is_neox_style, {
|
is_neox_style, {
|
||||||
"type": "linear",
|
"rope_type": "linear",
|
||||||
"factor": scaling_factors
|
"factor": scaling_factors
|
||||||
})
|
})
|
||||||
linear_rope = linear_rope.to(dtype=dtype)
|
linear_rope = linear_rope.to(dtype=dtype)
|
||||||
|
@ -64,9 +64,9 @@ def test_get_sliding_window():
|
|||||||
|
|
||||||
|
|
||||||
def test_rope_customization():
|
def test_rope_customization():
|
||||||
TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
|
TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
|
||||||
TEST_ROPE_THETA = 16_000_000.0
|
TEST_ROPE_THETA = 16_000_000.0
|
||||||
LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
|
LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}
|
||||||
|
|
||||||
llama_model_config = ModelConfig(
|
llama_model_config = ModelConfig(
|
||||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||||
|
@ -1739,16 +1739,10 @@ def _get_and_verify_max_len(
|
|||||||
|
|
||||||
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
||||||
if rope_scaling is not None:
|
if rope_scaling is not None:
|
||||||
if "type" in rope_scaling:
|
# No need to consider "type" key because of patch_rope_scaling when
|
||||||
rope_type = rope_scaling["type"]
|
# loading HF config
|
||||||
elif "rope_type" in rope_scaling:
|
rope_type = rope_scaling["rope_type"]
|
||||||
rope_type = rope_scaling["rope_type"]
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"rope_scaling must have a 'type' or 'rope_type' key.")
|
|
||||||
|
|
||||||
# The correct one should be "longrope", kept "su" here
|
|
||||||
# to be backward compatible
|
|
||||||
if rope_type not in ("su", "longrope", "llama3"):
|
if rope_type not in ("su", "longrope", "llama3"):
|
||||||
if disable_sliding_window:
|
if disable_sliding_window:
|
||||||
# TODO(robertgshaw): Find a model that supports rope_scaling
|
# TODO(robertgshaw): Find a model that supports rope_scaling
|
||||||
@ -1758,11 +1752,10 @@ def _get_and_verify_max_len(
|
|||||||
"with rope_scaling. Please raise an issue so we can "
|
"with rope_scaling. Please raise an issue so we can "
|
||||||
"investigate.")
|
"investigate.")
|
||||||
|
|
||||||
if rope_type == "mrope":
|
# NOTE: rope_type == "default" does not define factor
|
||||||
scaling_factor = 1
|
# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
|
||||||
else:
|
scaling_factor = rope_scaling.get("factor", 1.0)
|
||||||
assert "factor" in rope_scaling
|
|
||||||
scaling_factor = rope_scaling["factor"]
|
|
||||||
if rope_type == "yarn":
|
if rope_type == "yarn":
|
||||||
derived_max_model_len = rope_scaling[
|
derived_max_model_len = rope_scaling[
|
||||||
"original_max_position_embeddings"]
|
"original_max_position_embeddings"]
|
||||||
|
@ -454,11 +454,12 @@ class EngineArgs:
|
|||||||
'None, we assume the model weights are not '
|
'None, we assume the model weights are not '
|
||||||
'quantized and use `dtype` to determine the data '
|
'quantized and use `dtype` to determine the data '
|
||||||
'type of the weights.')
|
'type of the weights.')
|
||||||
parser.add_argument('--rope-scaling',
|
parser.add_argument(
|
||||||
default=None,
|
'--rope-scaling',
|
||||||
type=json.loads,
|
default=None,
|
||||||
help='RoPE scaling configuration in JSON format. '
|
type=json.loads,
|
||||||
'For example, {"type":"dynamic","factor":2.0}')
|
help='RoPE scaling configuration in JSON format. '
|
||||||
|
'For example, {"rope_type":"dynamic","factor":2.0}')
|
||||||
parser.add_argument('--rope-theta',
|
parser.add_argument('--rope-theta',
|
||||||
default=None,
|
default=None,
|
||||||
type=float,
|
type=float,
|
||||||
|
@ -920,13 +920,10 @@ def get_rope(
|
|||||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||||
is_neox_style, dtype)
|
is_neox_style, dtype)
|
||||||
else:
|
else:
|
||||||
scaling_type = rope_scaling[
|
scaling_type = rope_scaling["rope_type"]
|
||||||
"type"] if "type" in rope_scaling else rope_scaling["rope_type"]
|
|
||||||
# The correct one should be "longrope" but keep "su" here
|
|
||||||
# for backward compatible
|
|
||||||
if scaling_type not in {"su", "longrope"}:
|
|
||||||
scaling_factor = rope_scaling.get("factor", 1.0)
|
|
||||||
if scaling_type == "llama3":
|
if scaling_type == "llama3":
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
low_freq_factor = rope_scaling["low_freq_factor"]
|
low_freq_factor = rope_scaling["low_freq_factor"]
|
||||||
high_freq_factor = rope_scaling["high_freq_factor"]
|
high_freq_factor = rope_scaling["high_freq_factor"]
|
||||||
original_max_position = rope_scaling[
|
original_max_position = rope_scaling[
|
||||||
@ -937,16 +934,39 @@ def get_rope(
|
|||||||
scaling_factor, low_freq_factor,
|
scaling_factor, low_freq_factor,
|
||||||
high_freq_factor,
|
high_freq_factor,
|
||||||
original_max_position)
|
original_max_position)
|
||||||
|
elif scaling_type == "default":
|
||||||
|
if "mrope_section" in rope_scaling:
|
||||||
|
rotary_emb = MRotaryEmbedding(
|
||||||
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
max_position,
|
||||||
|
base,
|
||||||
|
is_neox_style,
|
||||||
|
dtype,
|
||||||
|
mrope_section=rope_scaling["mrope_section"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
rotary_emb = RotaryEmbedding(
|
||||||
|
head_size,
|
||||||
|
rotary_dim,
|
||||||
|
max_position,
|
||||||
|
base,
|
||||||
|
is_neox_style,
|
||||||
|
dtype,
|
||||||
|
)
|
||||||
elif scaling_type == "linear":
|
elif scaling_type == "linear":
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
|
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
|
||||||
max_position, base,
|
max_position, base,
|
||||||
is_neox_style,
|
is_neox_style,
|
||||||
scaling_factor, dtype)
|
scaling_factor, dtype)
|
||||||
elif scaling_type == "dynamic":
|
elif scaling_type == "dynamic":
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
||||||
head_size, rotary_dim, max_position, base, is_neox_style,
|
head_size, rotary_dim, max_position, base, is_neox_style,
|
||||||
scaling_factor, dtype)
|
scaling_factor, dtype)
|
||||||
elif scaling_type == "yarn":
|
elif scaling_type == "yarn":
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
original_max_position = rope_scaling[
|
original_max_position = rope_scaling[
|
||||||
"original_max_position_embeddings"]
|
"original_max_position_embeddings"]
|
||||||
extra_kwargs = {
|
extra_kwargs = {
|
||||||
@ -961,6 +981,7 @@ def get_rope(
|
|||||||
scaling_factor, dtype,
|
scaling_factor, dtype,
|
||||||
**extra_kwargs)
|
**extra_kwargs)
|
||||||
elif scaling_type == "deepseek_yarn":
|
elif scaling_type == "deepseek_yarn":
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
original_max_position = rope_scaling[
|
original_max_position = rope_scaling[
|
||||||
"original_max_position_embeddings"]
|
"original_max_position_embeddings"]
|
||||||
# assert max_position == original_max_position * scaling_factor
|
# assert max_position == original_max_position * scaling_factor
|
||||||
@ -973,9 +994,7 @@ def get_rope(
|
|||||||
rotary_emb = DeepseekScalingRotaryEmbedding(
|
rotary_emb = DeepseekScalingRotaryEmbedding(
|
||||||
head_size, rotary_dim, original_max_position, base,
|
head_size, rotary_dim, original_max_position, base,
|
||||||
is_neox_style, scaling_factor, dtype, **extra_kwargs)
|
is_neox_style, scaling_factor, dtype, **extra_kwargs)
|
||||||
# The correct one should be "longrope" but keep "su" here
|
elif scaling_type == "longrope":
|
||||||
# for backward compatible
|
|
||||||
elif scaling_type == "su" or scaling_type == "longrope":
|
|
||||||
short_factor = rope_scaling["short_factor"]
|
short_factor = rope_scaling["short_factor"]
|
||||||
long_factor = rope_scaling["long_factor"]
|
long_factor = rope_scaling["long_factor"]
|
||||||
original_max_position = rope_scaling[
|
original_max_position = rope_scaling[
|
||||||
@ -989,16 +1008,6 @@ def get_rope(
|
|||||||
head_size, rotary_dim, max_position, original_max_position,
|
head_size, rotary_dim, max_position, original_max_position,
|
||||||
base, is_neox_style, dtype, short_factor, long_factor,
|
base, is_neox_style, dtype, short_factor, long_factor,
|
||||||
**extra_kwargs)
|
**extra_kwargs)
|
||||||
elif scaling_type == "mrope":
|
|
||||||
rotary_emb = MRotaryEmbedding(
|
|
||||||
head_size,
|
|
||||||
rotary_dim,
|
|
||||||
max_position,
|
|
||||||
base,
|
|
||||||
is_neox_style,
|
|
||||||
dtype,
|
|
||||||
mrope_section=rope_scaling["mrope_section"],
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||||
_ROPE_DICT[key] = rotary_emb
|
_ROPE_DICT[key] = rotary_emb
|
||||||
|
@ -242,7 +242,7 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.o_proj")
|
prefix=f"{prefix}.o_proj")
|
||||||
rope_scaling['type'] = 'deepseek_yarn'
|
rope_scaling["rope_type"] = 'deepseek_yarn'
|
||||||
self.rotary_emb = get_rope(qk_rope_head_dim,
|
self.rotary_emb = get_rope(qk_rope_head_dim,
|
||||||
rotary_dim=qk_rope_head_dim,
|
rotary_dim=qk_rope_head_dim,
|
||||||
max_position=max_position_embeddings,
|
max_position=max_position_embeddings,
|
||||||
|
@ -179,7 +179,7 @@ class Phi3SmallSelfAttention(nn.Module):
|
|||||||
rope_scaling["factor"] = self.rope_position_scale
|
rope_scaling["factor"] = self.rope_position_scale
|
||||||
else:
|
else:
|
||||||
rope_scaling = {
|
rope_scaling = {
|
||||||
"type": "linear",
|
"rope_type": "linear",
|
||||||
"factor": self.rope_position_scale,
|
"factor": self.rope_position_scale,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,6 +34,8 @@ from PIL import Image
|
|||||||
from transformers.image_utils import (get_image_size,
|
from transformers.image_utils import (get_image_size,
|
||||||
infer_channel_dimension_format,
|
infer_channel_dimension_format,
|
||||||
to_numpy_array)
|
to_numpy_array)
|
||||||
|
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
|
||||||
|
Qwen2VLConfig, Qwen2VLVisionConfig)
|
||||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
|
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
|
||||||
make_batched_images, make_batched_videos, smart_resize)
|
make_batched_images, make_batched_videos, smart_resize)
|
||||||
|
|
||||||
@ -62,8 +64,7 @@ from vllm.multimodal.base import MultiModalData
|
|||||||
from vllm.multimodal.image import cached_get_image_processor
|
from vllm.multimodal.image import cached_get_image_processor
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors, SequenceData
|
from vllm.sequence import IntermediateTensors, SequenceData
|
||||||
from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
|
from vllm.transformers_utils.config import uses_mrope
|
||||||
Qwen2VLVisionConfig)
|
|
||||||
from vllm.transformers_utils.processor import get_processor
|
from vllm.transformers_utils.processor import get_processor
|
||||||
from vllm.utils import is_cpu
|
from vllm.utils import is_cpu
|
||||||
|
|
||||||
@ -1061,8 +1062,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
if image_input is None and video_input is None:
|
if image_input is None and video_input is None:
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
else:
|
else:
|
||||||
rope_scaling = getattr(self.config, "rope_scaling", {})
|
if uses_mrope(self.config):
|
||||||
if rope_scaling.get("type", None) == "mrope":
|
|
||||||
assert positions.ndim == 2 and positions.size(0) == 3, (
|
assert positions.ndim == 2 and positions.size(0) == 3, (
|
||||||
"multimodal section rotary embedding requires "
|
"multimodal section rotary embedding requires "
|
||||||
f"(3, seq_len) positions, but got {positions.size()}")
|
f"(3, seq_len) positions, but got {positions.size()}")
|
||||||
|
@ -23,8 +23,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
|
|||||||
MedusaConfig, MllamaConfig,
|
MedusaConfig, MllamaConfig,
|
||||||
MLPSpeculatorConfig, MPTConfig,
|
MLPSpeculatorConfig, MPTConfig,
|
||||||
NemotronConfig, NVLM_D_Config,
|
NemotronConfig, NVLM_D_Config,
|
||||||
Qwen2VLConfig, RWConfig,
|
RWConfig, SolarConfig,
|
||||||
SolarConfig, UltravoxConfig)
|
UltravoxConfig)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
|
|
||||||
@ -57,7 +57,6 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
|||||||
"NVLM_D": NVLM_D_Config,
|
"NVLM_D": NVLM_D_Config,
|
||||||
"solar": SolarConfig,
|
"solar": SolarConfig,
|
||||||
"ultravox": UltravoxConfig,
|
"ultravox": UltravoxConfig,
|
||||||
"qwen2_vl": Qwen2VLConfig,
|
|
||||||
**_CONFIG_REGISTRY_OVERRIDE_HF
|
**_CONFIG_REGISTRY_OVERRIDE_HF
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,6 +90,43 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def patch_rope_scaling(config: PretrainedConfig) -> None:
|
||||||
|
"""Provide backwards compatibility for RoPE."""
|
||||||
|
text_config = getattr(config, "text_config", None)
|
||||||
|
if text_config is not None:
|
||||||
|
patch_rope_scaling(text_config)
|
||||||
|
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
if rope_scaling is not None:
|
||||||
|
patch_rope_scaling_dict(rope_scaling)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
|
||||||
|
if "rope_type" not in rope_scaling and "type" in rope_scaling:
|
||||||
|
rope_scaling["rope_type"] = rope_scaling["type"]
|
||||||
|
logger.info("Replacing legacy 'type' key with 'rope_type'")
|
||||||
|
|
||||||
|
if "rope_type" not in rope_scaling:
|
||||||
|
raise ValueError("rope_scaling should have a 'rope_type' key")
|
||||||
|
|
||||||
|
if rope_scaling["rope_type"] == "su":
|
||||||
|
rope_scaling["rope_type"] = "longrope"
|
||||||
|
logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
|
||||||
|
elif rope_scaling["rope_type"] == "mrope":
|
||||||
|
assert "mrope_section" in rope_scaling
|
||||||
|
rope_scaling["rope_type"] = "default"
|
||||||
|
logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
|
||||||
|
|
||||||
|
|
||||||
|
def uses_mrope(config: PretrainedConfig) -> bool:
|
||||||
|
"""Detect if the model with this config uses M-ROPE."""
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
if rope_scaling is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return "mrope_section" in rope_scaling
|
||||||
|
|
||||||
|
|
||||||
def get_config(
|
def get_config(
|
||||||
model: Union[str, Path],
|
model: Union[str, Path],
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
@ -191,6 +227,8 @@ def get_config(
|
|||||||
)
|
)
|
||||||
config.update({key: value})
|
config.update({key: value})
|
||||||
|
|
||||||
|
patch_rope_scaling(config)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,8 +14,6 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
|
|||||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||||
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
||||||
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
|
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
|
||||||
from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
|
|
||||||
Qwen2VLVisionConfig)
|
|
||||||
from vllm.transformers_utils.configs.solar import SolarConfig
|
from vllm.transformers_utils.configs.solar import SolarConfig
|
||||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||||
|
|
||||||
@ -35,6 +33,4 @@ __all__ = [
|
|||||||
"NVLM_D_Config",
|
"NVLM_D_Config",
|
||||||
"SolarConfig",
|
"SolarConfig",
|
||||||
"UltravoxConfig",
|
"UltravoxConfig",
|
||||||
"Qwen2VLConfig",
|
|
||||||
"Qwen2VLVisionConfig",
|
|
||||||
]
|
]
|
||||||
|
@ -1,131 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Qwen2VL model configuration"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLVisionConfig(PretrainedConfig):
|
|
||||||
model_type = "qwen2_vl"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
depth=32,
|
|
||||||
embed_dim=1280,
|
|
||||||
hidden_size=3584,
|
|
||||||
hidden_act="quick_gelu",
|
|
||||||
mlp_ratio=4,
|
|
||||||
num_heads=16,
|
|
||||||
in_channels=3,
|
|
||||||
patch_size=14,
|
|
||||||
spatial_merge_size=2,
|
|
||||||
temporal_patch_size=2,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.depth = depth
|
|
||||||
self.embed_dim = embed_dim
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
self.mlp_ratio = mlp_ratio
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.spatial_merge_size = spatial_merge_size
|
|
||||||
self.temporal_patch_size = temporal_patch_size
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
|
|
||||||
os.PathLike],
|
|
||||||
**kwargs) -> "PretrainedConfig":
|
|
||||||
cls._set_token_in_kwargs(kwargs)
|
|
||||||
|
|
||||||
config_dict, kwargs = cls.get_config_dict(
|
|
||||||
pretrained_model_name_or_path, **kwargs)
|
|
||||||
|
|
||||||
if config_dict.get("model_type") == "qwen2_vl":
|
|
||||||
config_dict = config_dict["vision_config"]
|
|
||||||
|
|
||||||
return cls.from_dict(config_dict, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLConfig(PretrainedConfig):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size=152064,
|
|
||||||
hidden_size=8192,
|
|
||||||
intermediate_size=29568,
|
|
||||||
num_hidden_layers=80,
|
|
||||||
num_attention_heads=64,
|
|
||||||
num_key_value_heads=8,
|
|
||||||
hidden_act="silu",
|
|
||||||
max_position_embeddings=32768,
|
|
||||||
initializer_range=0.02,
|
|
||||||
rms_norm_eps=1e-05,
|
|
||||||
use_cache=True,
|
|
||||||
tie_word_embeddings=False,
|
|
||||||
rope_theta=1000000.0,
|
|
||||||
use_sliding_window=False,
|
|
||||||
sliding_window=4096,
|
|
||||||
max_window_layers=80,
|
|
||||||
attention_dropout=0.0,
|
|
||||||
vision_config=None,
|
|
||||||
rope_scaling=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if isinstance(vision_config, dict):
|
|
||||||
self.vision_config = Qwen2VLVisionConfig(**vision_config)
|
|
||||||
elif vision_config is None:
|
|
||||||
self.vision_config = Qwen2VLVisionConfig()
|
|
||||||
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.num_hidden_layers = num_hidden_layers
|
|
||||||
self.num_attention_heads = num_attention_heads
|
|
||||||
self.use_sliding_window = use_sliding_window
|
|
||||||
self.sliding_window = sliding_window
|
|
||||||
self.max_window_layers = max_window_layers
|
|
||||||
|
|
||||||
# for backward compatibility
|
|
||||||
if num_key_value_heads is None:
|
|
||||||
num_key_value_heads = num_attention_heads
|
|
||||||
|
|
||||||
self.num_key_value_heads = num_key_value_heads
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
self.rms_norm_eps = rms_norm_eps
|
|
||||||
self.use_cache = use_cache
|
|
||||||
self.rope_theta = rope_theta
|
|
||||||
self.attention_dropout = attention_dropout
|
|
||||||
self.rope_scaling = rope_scaling
|
|
||||||
|
|
||||||
# NOTE: the following section from original transformers config
|
|
||||||
# for Qwen2-VL is commented out to address rope config loading issue
|
|
||||||
#
|
|
||||||
# if self.rope_scaling is not None and "type" in self.rope_scaling:
|
|
||||||
# if self.rope_scaling["type"] == "mrope":
|
|
||||||
# self.rope_scaling["type"] = "default"
|
|
||||||
# self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
|
||||||
# rope_config_validation(self)
|
|
||||||
|
|
||||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
|
@ -19,6 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
|||||||
MultiModalInputs)
|
MultiModalInputs)
|
||||||
from vllm.sequence import (IntermediateTensors, SequenceData,
|
from vllm.sequence import (IntermediateTensors, SequenceData,
|
||||||
SequenceGroupMetadata)
|
SequenceGroupMetadata)
|
||||||
|
from vllm.transformers_utils.config import uses_mrope
|
||||||
from vllm.utils import make_tensor_with_pad
|
from vllm.utils import make_tensor_with_pad
|
||||||
from vllm.worker.model_runner_base import (
|
from vllm.worker.model_runner_base import (
|
||||||
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
||||||
@ -439,10 +440,7 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
|
|||||||
def model_is_mrope(self) -> bool:
|
def model_is_mrope(self) -> bool:
|
||||||
"""Detect if the model has "mrope" rope_scaling type.
|
"""Detect if the model has "mrope" rope_scaling type.
|
||||||
mrope requires keep "rope_deltas" between prompt and decoding phases."""
|
mrope requires keep "rope_deltas" between prompt and decoding phases."""
|
||||||
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
|
return uses_mrope(self.model_config.hf_config)
|
||||||
if rope_scaling is None:
|
|
||||||
return False
|
|
||||||
return rope_scaling.get("type", None) == "mrope"
|
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
self.model = get_model(model_config=self.model_config,
|
self.model = get_model(model_config=self.model_config,
|
||||||
|
@ -47,6 +47,7 @@ from vllm.prompt_adapter.worker_manager import (
|
|||||||
LRUCacheWorkerPromptAdapterManager)
|
LRUCacheWorkerPromptAdapterManager)
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||||
|
from vllm.transformers_utils.config import uses_mrope
|
||||||
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
|
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
|
||||||
flatten_2d_lists, is_hip, is_pin_memory_available,
|
flatten_2d_lists, is_hip, is_pin_memory_available,
|
||||||
supports_dynamo)
|
supports_dynamo)
|
||||||
@ -1379,10 +1380,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
def model_is_mrope(self) -> bool:
|
def model_is_mrope(self) -> bool:
|
||||||
"""Detect if the model has "mrope" rope_scaling type.
|
"""Detect if the model has "mrope" rope_scaling type.
|
||||||
mrope requires keep "rope_deltas" between prompt and decoding phases."""
|
mrope requires keep "rope_deltas" between prompt and decoding phases."""
|
||||||
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
|
return uses_mrope(self.model_config.hf_config)
|
||||||
if rope_scaling is None:
|
|
||||||
return False
|
|
||||||
return rope_scaling.get("type", None) == "mrope"
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
|
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user