[Bugfix] fix rope error when load models with different dtypes (#4835)
This commit is contained in:
parent
26148120b3
commit
33e0823de5
@ -1,4 +1,4 @@
|
||||
from itertools import accumulate
|
||||
from itertools import accumulate, product
|
||||
from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
@ -207,3 +207,45 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_rope_module_cache():
|
||||
MAX_POSITIONS = [123, 1234]
|
||||
BASES = [10000, 1000000]
|
||||
ROPE_SCALINGS = [
|
||||
None, {
|
||||
"type": "linear",
|
||||
"factor": (1, )
|
||||
}, {
|
||||
"type": "dynamic",
|
||||
"factor": 1
|
||||
}
|
||||
]
|
||||
settings = [
|
||||
HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
|
||||
ROPE_SCALINGS, DTYPES
|
||||
]
|
||||
rope_setting_id_map = {}
|
||||
for setting in product(*settings):
|
||||
head_size, rotary_dim, max_position, base, \
|
||||
is_neox_stype, rope_scaling, dtype = setting
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base,
|
||||
is_neox_stype, rope_scaling, dtype)
|
||||
# different settings cannot share the same rope module
|
||||
assert id(rope) not in rope_setting_id_map.values()
|
||||
assert all(x.dtype == dtype for x in rope.buffers())
|
||||
assert all(x.dtype == dtype for x in rope.parameters())
|
||||
rope_setting_id_map[str(setting)] = id(rope)
|
||||
|
||||
for setting in product(*settings):
|
||||
head_size, rotary_dim, max_position, base, \
|
||||
is_neox_stype, rope_scaling, dtype = setting
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base,
|
||||
is_neox_stype, rope_scaling, dtype)
|
||||
# check if cache take effect
|
||||
assert id(rope) == rope_setting_id_map[str(setting)]
|
||||
|
@ -53,6 +53,7 @@ class RotaryEmbedding(nn.Module):
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
@ -62,7 +63,7 @@ class RotaryEmbedding(nn.Module):
|
||||
self.is_neox_style = is_neox_style
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
cache = cache.to(torch.get_default_dtype())
|
||||
cache = cache.to(dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
@ -178,12 +179,13 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factors: Union[List[float], float],
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
if isinstance(scaling_factors, float):
|
||||
scaling_factors = [scaling_factors]
|
||||
self.scaling_factors = scaling_factors
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style)
|
||||
is_neox_style, dtype)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
@ -219,10 +221,11 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style)
|
||||
is_neox_style, dtype)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||
@ -299,6 +302,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
@ -314,7 +318,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
||||
self.mscale = float(
|
||||
_yarn_get_mscale(self.scaling_factor) * attn_factor)
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style)
|
||||
is_neox_style, dtype)
|
||||
|
||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||
pos_freqs = self.base**(
|
||||
@ -359,6 +363,7 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
|
||||
original_max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
short_factor: List[float],
|
||||
long_factor: List[float],
|
||||
short_mscale: float = 1.1,
|
||||
@ -385,14 +390,14 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
|
||||
|
||||
short_cache = self._compute_cos_sin_cache(
|
||||
original_max_position_embeddings, short_factor, short_mscale)
|
||||
short_cache = short_cache.to(torch.get_default_dtype())
|
||||
short_cache = short_cache.to(dtype)
|
||||
self.register_buffer("short_cos_sin_cache",
|
||||
short_cache,
|
||||
persistent=False)
|
||||
|
||||
long_cache = self._compute_cos_sin_cache(max_position_embeddings,
|
||||
long_factor, long_mscale)
|
||||
long_cache = long_cache.to(torch.get_default_dtype())
|
||||
long_cache = long_cache.to(dtype)
|
||||
self.register_buffer("long_cos_sin_cache",
|
||||
long_cache,
|
||||
persistent=False)
|
||||
@ -463,7 +468,10 @@ def get_rope(
|
||||
base: int,
|
||||
is_neox_style: bool = True,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> RotaryEmbedding:
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
if rope_scaling is not None:
|
||||
# Transforms every value that is a list into a tuple for caching calls
|
||||
rope_scaling_tuple = {
|
||||
@ -474,12 +482,12 @@ def get_rope(
|
||||
else:
|
||||
rope_scaling_args = None
|
||||
key = (head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
rope_scaling_args)
|
||||
rope_scaling_args, dtype)
|
||||
if key in _ROPE_DICT:
|
||||
return _ROPE_DICT[key]
|
||||
if rope_scaling is None:
|
||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style)
|
||||
is_neox_style, dtype)
|
||||
else:
|
||||
scaling_type = rope_scaling["type"]
|
||||
if scaling_type != "su":
|
||||
@ -488,11 +496,11 @@ def get_rope(
|
||||
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
|
||||
max_position, base,
|
||||
is_neox_style,
|
||||
scaling_factor)
|
||||
scaling_factor, dtype)
|
||||
elif scaling_type == "dynamic":
|
||||
rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
scaling_factor)
|
||||
scaling_factor, dtype)
|
||||
elif scaling_type == "yarn":
|
||||
original_max_position = rope_scaling[
|
||||
"original_max_position_embeddings"]
|
||||
@ -505,7 +513,7 @@ def get_rope(
|
||||
rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
|
||||
original_max_position,
|
||||
base, is_neox_style,
|
||||
scaling_factor,
|
||||
scaling_factor, dtype,
|
||||
**extra_kwargs)
|
||||
elif scaling_type == "su":
|
||||
short_factor = rope_scaling["short_factor"]
|
||||
@ -519,7 +527,8 @@ def get_rope(
|
||||
}
|
||||
rotary_emb = Phi3SuScaledRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, original_max_position,
|
||||
base, is_neox_style, short_factor, long_factor, **extra_kwargs)
|
||||
base, is_neox_style, dtype, short_factor, long_factor,
|
||||
**extra_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
_ROPE_DICT[key] = rotary_emb
|
||||
|
Loading…
x
Reference in New Issue
Block a user