[Minor] Move RoPE selection logic to get_rope (#1633)
parent eb825c1e74
commit 054072bee5
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -10,9 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
 from vllm import attention_ops
 from vllm import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.rotary_embedding import (
-    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding, YaRNScalingRotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import get_rope

 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -319,36 +317,8 @@ class PagedAttentionWithRoPE(PagedAttention):
                          scale,
                          num_kv_heads,
                          sliding_window=sliding_window)
-        if rope_scaling is None:
-            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
-                                              max_position, base,
-                                              is_neox_style)
-        else:
-            scaling_type = rope_scaling["type"]
-            scaling_factor = rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LinearScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "dynamic":
-                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "yarn":
-                original_max_position = rope_scaling[
-                    "original_max_position_embeddings"]
-                assert max_position == original_max_position * scaling_factor
-                extra_kwargs = {
-                    k: v
-                    for k, v in rope_scaling.items()
-                    if k in ("extrapolation_factor", "attn_factor",
-                             "beta_fast", "beta_slow")
-                }
-                self.rotary_emb = YaRNScalingRotaryEmbedding(
-                    head_size, rotary_dim, original_max_position, base,
-                    is_neox_style, scaling_factor, **extra_kwargs)
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
+                                   is_neox_style, rope_scaling)

     def forward(
         self,
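Note: with this change the constructor above delegates all scaling-type dispatch to get_rope, so new scaling types only need to be registered in one place rather than in every attention layer. A minimal sketch of exercising the new entry point (the head_size, max_position, and base values below are illustrative, not taken from the diff):

from vllm.model_executor.layers.rotary_embedding import get_rope

# Linear scaling config: get_rope dispatches to LinearScalingRotaryEmbedding.
rotary_emb = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=4096,
    base=10000,
    is_neox_style=True,
    rope_scaling={"type": "linear", "factor": 2.0},
)

# No scaling config: get_rope falls back to plain RotaryEmbedding.
rotary_emb = get_rope(128, 128, 4096, 10000, True, rope_scaling=None)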
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -22,7 +22,7 @@
 # limitations under the License.
 """Rotary Positional Embeddings."""
 import math
-from typing import Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -271,3 +271,46 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         sin = (freqs.sin() * self.mscale)
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool,
+    rope_scaling: Optional[Dict[str, Any]],
+) -> RotaryEmbedding:
+    if rope_scaling is None:
+        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position,
+                                     base, is_neox_style)
+    else:
+        scaling_type = rope_scaling["type"]
+        scaling_factor = rope_scaling["factor"]
+        if scaling_type == "linear":
+            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
+                                                      max_position, base,
+                                                      is_neox_style,
+                                                      scaling_factor)
+        elif scaling_type == "dynamic":
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size, rotary_dim, max_position, base, is_neox_style,
+                scaling_factor)
+        elif scaling_type == "yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            assert max_position == original_max_position * scaling_factor
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
+                         "beta_slow")
+            }
+            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
+                                                    original_max_position,
+                                                    base, is_neox_style,
+                                                    scaling_factor,
+                                                    **extra_kwargs)
+        else:
+            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+    return rotary_emb
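The "yarn" branch is the only one with extra preconditions: max_position must equal original_max_position_embeddings times the scaling factor, and only the four whitelisted keys are forwarded as keyword arguments. A hedged sketch of a config that satisfies both constraints (all numeric values are made up for illustration; attn_factor and beta_fast are two of the whitelisted keys):

from vllm.model_executor.layers.rotary_embedding import get_rope

# Hypothetical YaRN config; key names follow the filter in get_rope above.
rope_scaling = {
    "type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 2048,
    "attn_factor": 1.0,  # forwarded via **extra_kwargs
    "beta_fast": 32,     # forwarded via **extra_kwargs
}
rotary_emb = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=8192,  # must equal 2048 * 4.0, or the assert fires
    base=10000,
    is_neox_style=True,
    rope_scaling=rope_scaling,
)  # returns a YaRNScalingRotaryEmbedding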