[Minor] Move RoPE selection logic to get_rope (#1633)

Author: Woosuk Kwon, 2023-11-12 16:04:50 -08:00 (committed by GitHub)
parent eb825c1e74
commit 054072bee5
2 changed files with 47 additions and 34 deletions

vllm/model_executor/layers/attention.py

@@ -10,9 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
 from vllm import attention_ops
 from vllm import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.rotary_embedding import (
-    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding, YaRNScalingRotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -319,36 +317,8 @@ class PagedAttentionWithRoPE(PagedAttention):
                          scale,
                          num_kv_heads,
                          sliding_window=sliding_window)
-        if rope_scaling is None:
-            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
-                                              max_position, base,
-                                              is_neox_style)
-        else:
-            scaling_type = rope_scaling["type"]
-            scaling_factor = rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LinearScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "dynamic":
-                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "yarn":
-                original_max_position = rope_scaling[
-                    "original_max_position_embeddings"]
-                assert max_position == original_max_position * scaling_factor
-                extra_kwargs = {
-                    k: v
-                    for k, v in rope_scaling.items()
-                    if k in ("extrapolation_factor", "attn_factor",
-                             "beta_fast", "beta_slow")
-                }
-                self.rotary_emb = YaRNScalingRotaryEmbedding(
-                    head_size, rotary_dim, original_max_position, base,
-                    is_neox_style, scaling_factor, **extra_kwargs)
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
+                                   is_neox_style, rope_scaling)
 
     def forward(
         self,

vllm/model_executor/layers/rotary_embedding.py

@@ -22,7 +22,7 @@
 # limitations under the License.
 """Rotary Positional Embeddings."""
 import math
-from typing import Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -271,3 +271,46 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         sin = (freqs.sin() * self.mscale)
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool,
+    rope_scaling: Optional[Dict[str, Any]],
+) -> RotaryEmbedding:
+    if rope_scaling is None:
+        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                                     is_neox_style)
+    else:
+        scaling_type = rope_scaling["type"]
+        scaling_factor = rope_scaling["factor"]
+        if scaling_type == "linear":
+            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
+                                                      max_position, base,
+                                                      is_neox_style,
+                                                      scaling_factor)
+        elif scaling_type == "dynamic":
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size, rotary_dim, max_position, base, is_neox_style,
+                scaling_factor)
+        elif scaling_type == "yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            assert max_position == original_max_position * scaling_factor
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
+                         "beta_slow")
+            }
+            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
+                                                    original_max_position,
+                                                    base, is_neox_style,
+                                                    scaling_factor,
+                                                    **extra_kwargs)
+        else:
+            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+    return rotary_emb
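
For reference, a minimal usage sketch of the new factory. The parameter values below are hypothetical, chosen only to exercise the dispatch branches in the diff above; the rope_scaling dicts use exactly the keys the function reads ("type", "factor", "original_max_position_embeddings").

    from vllm.model_executor.layers.rotary_embedding import get_rope

    # rope_scaling=None: returns a plain RotaryEmbedding.
    rope = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                    base=10000, is_neox_style=True, rope_scaling=None)

    # "linear": returns LinearScalingRotaryEmbedding with factor 2.0.
    rope = get_rope(head_size=128, rotary_dim=128, max_position=8192,
                    base=10000, is_neox_style=True,
                    rope_scaling={"type": "linear", "factor": 2.0})

    # "yarn": the assert requires
    # max_position == original_max_position_embeddings * factor.
    rope = get_rope(head_size=128, rotary_dim=128, max_position=8192,
                    base=10000, is_neox_style=True,
                    rope_scaling={"type": "yarn", "factor": 2.0,
                                  "original_max_position_embeddings": 4096})

In every case the caller (here PagedAttentionWithRoPE) no longer needs to know which RotaryEmbedding subclass is constructed; an unrecognized "type" raises ValueError inside get_rope instead of at each call site.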