[Minor] Move RoPE selection logic to get_rope (#1633)

Authored by Woosuk Kwon on 2023-11-12 16:04:50 -08:00; committed by GitHub
commit 054072bee5 (parent eb825c1e74)
2 changed files with 47 additions and 34 deletions

vllm/model_executor/layers/attention.py

@@ -10,9 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
 from vllm import attention_ops
 from vllm import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.rotary_embedding import (
-    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding, YaRNScalingRotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -319,36 +317,8 @@ class PagedAttentionWithRoPE(PagedAttention):
                          scale,
                          num_kv_heads,
                          sliding_window=sliding_window)
-        if rope_scaling is None:
-            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
-                                              max_position, base,
-                                              is_neox_style)
-        else:
-            scaling_type = rope_scaling["type"]
-            scaling_factor = rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LinearScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "dynamic":
-                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "yarn":
-                original_max_position = rope_scaling[
-                    "original_max_position_embeddings"]
-                assert max_position == original_max_position * scaling_factor
-                extra_kwargs = {
-                    k: v
-                    for k, v in rope_scaling.items()
-                    if k in ("extrapolation_factor", "attn_factor",
-                             "beta_fast", "beta_slow")
-                }
-                self.rotary_emb = YaRNScalingRotaryEmbedding(
-                    head_size, rotary_dim, original_max_position, base,
-                    is_neox_style, scaling_factor, **extra_kwargs)
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
+                                   is_neox_style, rope_scaling)
 
     def forward(
         self,
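Note (editor's sketch, not part of the diff): the rope_scaling dict that the attention layer now forwards untouched to get_rope is the scaling spec normally read from the model's Hugging Face config. Illustrative shapes of the two cases the helper accepts, with hypothetical values:

# No scaling: get_rope falls back to plain RotaryEmbedding.
rope_scaling = None

# Linear scaling; "type" and "factor" are the keys the selection
# logic in get_rope reads.
rope_scaling = {"type": "linear", "factor": 2.0}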

vllm/model_executor/layers/rotary_embedding.py

@@ -22,7 +22,7 @@
 # limitations under the License.
 """Rotary Positional Embeddings."""
 import math
-from typing import Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -271,3 +271,46 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         sin = (freqs.sin() * self.mscale)
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool,
+    rope_scaling: Optional[Dict[str, Any]],
+) -> RotaryEmbedding:
+    if rope_scaling is None:
+        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                                     is_neox_style)
+    else:
+        scaling_type = rope_scaling["type"]
+        scaling_factor = rope_scaling["factor"]
+        if scaling_type == "linear":
+            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
+                                                      max_position, base,
+                                                      is_neox_style,
+                                                      scaling_factor)
+        elif scaling_type == "dynamic":
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size, rotary_dim, max_position, base, is_neox_style,
+                scaling_factor)
+        elif scaling_type == "yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            assert max_position == original_max_position * scaling_factor
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
+                         "beta_slow")
+            }
+            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
+                                                    original_max_position,
+                                                    base, is_neox_style,
+                                                    scaling_factor,
+                                                    **extra_kwargs)
+        else:
+            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+    return rotary_emb
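A hedged usage sketch of the new helper (parameter values are illustrative, not taken from the diff):

from vllm.model_executor.layers.rotary_embedding import get_rope

# Plain RoPE: no scaling config.
rope = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                base=10000, is_neox_style=True, rope_scaling=None)

# YaRN scaling: the assert in get_rope requires
# max_position == original_max_position_embeddings * factor.
rope = get_rope(head_size=128, rotary_dim=128, max_position=16384,
                base=10000, is_neox_style=True,
                rope_scaling={"type": "yarn", "factor": 4.0,
                              "original_max_position_embeddings": 4096})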