[Bugfix] Fix RoPE error when loading models with different dtypes (#4835)

Jinzhen Lin 2024-05-17 17:43:34 +08:00 committed by GitHub
parent 26148120b3
commit 33e0823de5
2 changed files with 64 additions and 13 deletions
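
What the diff below suggests: get_rope() memoizes constructed RoPE modules in _ROPE_DICT, but the cache key did not include a dtype, and the cos/sin cache was cast to torch.get_default_dtype() at construction time. Loading two models with different dtypes in the same process could therefore hand the second model a cached module whose cos_sin_cache was built for the first model's dtype, producing a dtype mismatch in the RoPE kernel. The fix threads an explicit dtype argument through every rotary-embedding class and adds it to the cache key. A minimal sketch of the fixed behavior; the import path is assumed and the argument values are arbitrary:

import torch

from vllm.model_executor.layers.rotary_embedding import get_rope  # path assumed

# With dtype in the cache key, each dtype gets its own module and cache.
rope_fp16 = get_rope(128, 128, 2048, 10000, True, None, torch.float16)
rope_bf16 = get_rope(128, 128, 2048, 10000, True, None, torch.bfloat16)
assert rope_fp16 is not rope_bf16
assert rope_fp16.cos_sin_cache.dtype == torch.float16
assert rope_bf16.cos_sin_cache.dtype == torch.bfloat16

# Identical settings (including dtype) still return the cached module.
assert get_rope(128, 128, 2048, 10000, True, None, torch.float16) is rope_fp16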


@@ -1,4 +1,4 @@
-from itertools import accumulate
+from itertools import accumulate, product
 from typing import List, Optional

 import pytest
@@ -207,3 +207,45 @@ def test_batched_rotary_embedding_multi_lora(
                           ref_key,
                           atol=get_default_atol(out_key),
                           rtol=get_default_rtol(out_key))
+
+
+@torch.inference_mode()
+def test_rope_module_cache():
+    MAX_POSITIONS = [123, 1234]
+    BASES = [10000, 1000000]
+    ROPE_SCALINGS = [
+        None, {
+            "type": "linear",
+            "factor": (1, )
+        }, {
+            "type": "dynamic",
+            "factor": 1
+        }
+    ]
+    settings = [
+        HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
+        ROPE_SCALINGS, DTYPES
+    ]
+    rope_setting_id_map = {}
+    for setting in product(*settings):
+        head_size, rotary_dim, max_position, base, \
+            is_neox_style, rope_scaling, dtype = setting
+        if rotary_dim is None:
+            rotary_dim = head_size
+        rope = get_rope(head_size, rotary_dim, max_position, base,
+                        is_neox_style, rope_scaling, dtype)
+        # Different settings must not share the same RoPE module.
+        assert id(rope) not in rope_setting_id_map.values()
+        assert all(x.dtype == dtype for x in rope.buffers())
+        assert all(x.dtype == dtype for x in rope.parameters())
+        rope_setting_id_map[str(setting)] = id(rope)
+
+    for setting in product(*settings):
+        head_size, rotary_dim, max_position, base, \
+            is_neox_style, rope_scaling, dtype = setting
+        if rotary_dim is None:
+            rotary_dim = head_size
+        rope = get_rope(head_size, rotary_dim, max_position, base,
+                        is_neox_style, rope_scaling, dtype)
+        # The same setting must hit the cache and return the same module.
+        assert id(rope) == rope_setting_id_map[str(setting)]

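The new test above sweeps the product of head sizes, rotary dims, max positions, bases, NeoX-style flags, scaling configs, and dtypes: the first pass asserts that every distinct setting gets its own module whose buffers are in the requested dtype, and the second pass asserts that repeating a setting returns the identical cached object. It can be run on its own with something like the following (the test file path is assumed):

pytest tests/kernels/test_pos_encoding.py -k test_rope_module_cache

The second changed file is the rotary-embedding implementation itself, where dtype is threaded through each class and into the get_rope() cache key.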

@@ -53,6 +53,7 @@ class RotaryEmbedding(nn.Module):
         max_position_embeddings: int,
         base: int,
         is_neox_style: bool,
+        dtype: torch.dtype,
     ) -> None:
         super().__init__()
         self.head_size = head_size
@@ -62,7 +63,7 @@ class RotaryEmbedding(nn.Module):
         self.is_neox_style = is_neox_style

         cache = self._compute_cos_sin_cache()
-        cache = cache.to(torch.get_default_dtype())
+        cache = cache.to(dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)

     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
@@ -178,12 +179,13 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factors: Union[List[float], float],
+        dtype: torch.dtype,
     ) -> None:
         if isinstance(scaling_factors, float):
             scaling_factors = [scaling_factors]
         self.scaling_factors = scaling_factors
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)

     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.base)
@@ -219,10 +221,11 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factor: float,
+        dtype: torch.dtype,
     ) -> None:
         self.scaling_factor = scaling_factor
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)

     def _compute_cos_sin_cache(self) -> torch.Tensor:
         # NOTE(woosuk): self.max_position_embeddings is the original
@@ -299,6 +302,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factor: float,
+        dtype: torch.dtype,
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
@@ -314,7 +318,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         self.mscale = float(
             _yarn_get_mscale(self.scaling_factor) * attn_factor)
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)

     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base**(
@@ -359,6 +363,7 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
         original_max_position_embeddings: int,
         base: int,
         is_neox_style: bool,
+        dtype: torch.dtype,
         short_factor: List[float],
         long_factor: List[float],
         short_mscale: float = 1.1,
@@ -385,14 +390,14 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
-        short_cache = short_cache.to(torch.get_default_dtype())
+        short_cache = short_cache.to(dtype)
         self.register_buffer("short_cos_sin_cache",
                              short_cache,
                              persistent=False)

         long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                  long_factor, long_mscale)
-        long_cache = long_cache.to(torch.get_default_dtype())
+        long_cache = long_cache.to(dtype)
         self.register_buffer("long_cos_sin_cache",
                              long_cache,
                              persistent=False)
@@ -463,7 +468,10 @@ def get_rope(
     base: int,
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> RotaryEmbedding:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
     if rope_scaling is not None:
         # Transforms every value that is a list into a tuple for caching calls
         rope_scaling_tuple = {
@@ -474,12 +482,12 @@ def get_rope(
     else:
         rope_scaling_args = None
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
-           rope_scaling_args)
+           rope_scaling_args, dtype)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
-                                     is_neox_style)
+                                     is_neox_style, dtype)
     else:
         scaling_type = rope_scaling["type"]
         if scaling_type != "su":
@@ -488,11 +496,11 @@ def get_rope(
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
-                                                      scaling_factor)
+                                                      scaling_factor, dtype)
         elif scaling_type == "dynamic":
             rotary_emb = DynamicNTKScalingRotaryEmbedding(
                 head_size, rotary_dim, max_position, base, is_neox_style,
-                scaling_factor)
+                scaling_factor, dtype)
         elif scaling_type == "yarn":
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
@@ -505,7 +513,7 @@ def get_rope(
             rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                     original_max_position,
                                                     base, is_neox_style,
-                                                    scaling_factor,
+                                                    scaling_factor, dtype,
                                                     **extra_kwargs)
         elif scaling_type == "su":
             short_factor = rope_scaling["short_factor"]
@@ -519,7 +527,8 @@ def get_rope(
             }
             rotary_emb = Phi3SuScaledRotaryEmbedding(
                 head_size, rotary_dim, max_position, original_max_position,
-                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
+                base, is_neox_style, dtype, short_factor, long_factor,
+                **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
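
For callers, the only visible change is the optional dtype parameter on get_rope(), which falls back to torch.get_default_dtype() as before. A hedged sketch of constructing a RoPE module for a half-precision model with linear scaling; the keyword names follow the signature above, the import path and concrete values are illustrative:

import torch

from vllm.model_executor.layers.rotary_embedding import get_rope  # path assumed

head_dim = 128  # illustrative model dimensions
rope = get_rope(
    head_size=head_dim,
    rotary_dim=head_dim,
    max_position=4096,
    base=10000,
    is_neox_style=True,
    rope_scaling={"type": "linear", "factor": 2.0},
    dtype=torch.float16,  # new: pin the cos/sin cache to the model's dtype
)
assert rope.cos_sin_cache.dtype == torch.float16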