[Minor] Move RoPE selection logic to get_rope (#1633)
parent eb825c1e74
commit 054072bee5
@@ -10,9 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
 from vllm import attention_ops
 from vllm import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.rotary_embedding import (
-    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding, YaRNScalingRotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -319,36 +317,8 @@ class PagedAttentionWithRoPE(PagedAttention):
             scale,
             num_kv_heads,
             sliding_window=sliding_window)
-        if rope_scaling is None:
-            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
-                                              max_position, base,
-                                              is_neox_style)
-        else:
-            scaling_type = rope_scaling["type"]
-            scaling_factor = rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LinearScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "dynamic":
-                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "yarn":
-                original_max_position = rope_scaling[
-                    "original_max_position_embeddings"]
-                assert max_position == original_max_position * scaling_factor
-                extra_kwargs = {
-                    k: v
-                    for k, v in rope_scaling.items()
-                    if k in ("extrapolation_factor", "attn_factor",
-                             "beta_fast", "beta_slow")
-                }
-                self.rotary_emb = YaRNScalingRotaryEmbedding(
-                    head_size, rotary_dim, original_max_position, base,
-                    is_neox_style, scaling_factor, **extra_kwargs)
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
+                                   is_neox_style, rope_scaling)
 
     def forward(
         self,
@@ -22,7 +22,7 @@
 # limitations under the License.
 """Rotary Positional Embeddings."""
 import math
-from typing import Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -271,3 +271,46 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         sin = (freqs.sin() * self.mscale)
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool,
+    rope_scaling: Optional[Dict[str, Any]],
+) -> RotaryEmbedding:
+    if rope_scaling is None:
+        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                                     is_neox_style)
+    else:
+        scaling_type = rope_scaling["type"]
+        scaling_factor = rope_scaling["factor"]
+        if scaling_type == "linear":
+            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
+                                                      max_position, base,
+                                                      is_neox_style,
+                                                      scaling_factor)
+        elif scaling_type == "dynamic":
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size, rotary_dim, max_position, base, is_neox_style,
+                scaling_factor)
+        elif scaling_type == "yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            assert max_position == original_max_position * scaling_factor
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
+                         "beta_slow")
+            }
+            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
+                                                    original_max_position,
+                                                    base, is_neox_style,
+                                                    scaling_factor,
+                                                    **extra_kwargs)
+        else:
+            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+    return rotary_emb
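For reference, a minimal usage sketch of the new helper, assuming this patch is applied and vLLM is importable. The concrete head size, base, context lengths, and scaling factors below are illustrative assumptions, not values taken from the patch.

# Illustrative only: head_size/base/max_position values are assumptions.
from vllm.model_executor.layers.rotary_embedding import get_rope

# No scaling dict -> plain RotaryEmbedding.
plain_rope = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                      base=10000, is_neox_style=True, rope_scaling=None)

# Linear scaling -> LinearScalingRotaryEmbedding, here for a 2x longer context.
linear_rope = get_rope(head_size=128, rotary_dim=128, max_position=8192,
                       base=10000, is_neox_style=True,
                       rope_scaling={"type": "linear", "factor": 2.0})

# YaRN scaling -> YaRNScalingRotaryEmbedding; get_rope asserts that
# max_position == original_max_position_embeddings * factor.
yarn_rope = get_rope(head_size=128, rotary_dim=128, max_position=8192,
                     base=10000, is_neox_style=True,
                     rope_scaling={"type": "yarn", "factor": 2.0,
                                   "original_max_position_embeddings": 4096})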