[Bugfix] fix rope error when load models with different dtypes (#4835)
commit 33e0823de5
parent 26148120b3
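get_rope() memoizes constructed rotary-embedding modules in the module-level _ROPE_DICT, and the cos/sin caches were previously materialized in torch.get_default_dtype(). When two models with different dtypes were loaded in the same process, the second could therefore be handed a cached RoPE module whose cos/sin buffers had the wrong dtype. The diff below threads an explicit dtype through RotaryEmbedding and all of its scaling subclasses and makes dtype part of the cache key. A minimal sketch of the behaviour after the patch (illustrative values, assuming the usual vllm import path; not part of the commit):

    import torch
    from vllm.model_executor.layers.rotary_embedding import get_rope  # assumed path

    # Same geometry, different dtypes: each call now gets (and caches) its own module.
    rope_fp16 = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                         base=10000, is_neox_style=True, dtype=torch.float16)
    rope_bf16 = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                         base=10000, is_neox_style=True, dtype=torch.bfloat16)

    assert rope_fp16 is not rope_bf16                       # dtype is part of the cache key
    assert rope_fp16.cos_sin_cache.dtype == torch.float16   # buffer follows the requested dtype
    assert rope_bf16.cos_sin_cache.dtype == torch.bfloat16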
tests/kernels/test_pos_encoding.py
@@ -1,4 +1,4 @@
-from itertools import accumulate
+from itertools import accumulate, product
 from typing import List, Optional
 
 import pytest
@@ -207,3 +207,45 @@ def test_batched_rotary_embedding_multi_lora(
                           ref_key,
                           atol=get_default_atol(out_key),
                           rtol=get_default_rtol(out_key))
+
+
+@torch.inference_mode()
+def test_rope_module_cache():
+    MAX_POSITIONS = [123, 1234]
+    BASES = [10000, 1000000]
+    ROPE_SCALINGS = [
+        None, {
+            "type": "linear",
+            "factor": (1, )
+        }, {
+            "type": "dynamic",
+            "factor": 1
+        }
+    ]
+    settings = [
+        HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
+        ROPE_SCALINGS, DTYPES
+    ]
+    rope_setting_id_map = {}
+    for setting in product(*settings):
+        head_size, rotary_dim, max_position, base, \
+            is_neox_stype, rope_scaling, dtype = setting
+        if rotary_dim is None:
+            rotary_dim = head_size
+        rope = get_rope(head_size, rotary_dim, max_position, base,
+                        is_neox_stype, rope_scaling, dtype)
+        # different settings cannot share the same rope module
+        assert id(rope) not in rope_setting_id_map.values()
+        assert all(x.dtype == dtype for x in rope.buffers())
+        assert all(x.dtype == dtype for x in rope.parameters())
+        rope_setting_id_map[str(setting)] = id(rope)
+
+    for setting in product(*settings):
+        head_size, rotary_dim, max_position, base, \
+            is_neox_stype, rope_scaling, dtype = setting
+        if rotary_dim is None:
+            rotary_dim = head_size
+        rope = get_rope(head_size, rotary_dim, max_position, base,
+                        is_neox_stype, rope_scaling, dtype)
+        # check if cache take effect
+        assert id(rope) == rope_setting_id_map[str(setting)]
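The new test sweeps the cartesian product of head sizes, rotary dims, max positions, bases, NeoX styles, scaling configs and dtypes, asserting that every distinct setting gets its own module whose buffers and parameters carry the requested dtype, and that repeating a setting hits the cache. The second property in isolation (made-up values, imports as in the sketch under the commit header):

    r1 = get_rope(64, 64, 1024, 10000, True, None, torch.float16)
    r2 = get_rope(64, 64, 1024, 10000, True, None, torch.float16)
    assert r1 is r2  # identical settings, including dtype, return the cached module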
vllm/model_executor/layers/rotary_embedding.py
@@ -53,6 +53,7 @@ class RotaryEmbedding(nn.Module):
         max_position_embeddings: int,
         base: int,
         is_neox_style: bool,
+        dtype: torch.dtype,
     ) -> None:
         super().__init__()
         self.head_size = head_size
@@ -62,7 +63,7 @@ class RotaryEmbedding(nn.Module):
         self.is_neox_style = is_neox_style
 
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(torch.get_default_dtype())
+        cache = cache.to(dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
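RotaryEmbedding now takes the target dtype explicitly and casts its cos/sin cache to it, instead of whatever torch.get_default_dtype() happens to be at construction time. Roughly (illustrative values, not from the commit):

    torch.set_default_dtype(torch.float32)      # the global default no longer decides the cache dtype
    rope = RotaryEmbedding(head_size=64, rotary_dim=64, max_position_embeddings=2048,
                           base=10000, is_neox_style=True, dtype=torch.float16)
    assert rope.cos_sin_cache.dtype == torch.float16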
@@ -178,12 +179,13 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factors: Union[List[float], float],
+        dtype: torch.dtype,
     ) -> None:
         if isinstance(scaling_factors, float):
             scaling_factors = [scaling_factors]
         self.scaling_factors = scaling_factors
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.base)
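LinearScalingRotaryEmbedding only forwards the new dtype to the base constructor; the DynamicNTKScaling, YaRNScaling and Phi3SuScaled hunks below apply the same mechanical change. For example (assumed values):

    lin = LinearScalingRotaryEmbedding(64, 64, 2048, 10000, True,
                                       scaling_factors=4.0, dtype=torch.bfloat16)
    assert lin.cos_sin_cache.dtype == torch.bfloat16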
@@ -219,10 +221,11 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factor: float,
+        dtype: torch.dtype,
     ) -> None:
         self.scaling_factor = scaling_factor
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         # NOTE(woosuk): self.max_position_embeddings is the original
@@ -299,6 +302,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factor: float,
+        dtype: torch.dtype,
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
@@ -314,7 +318,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         self.mscale = float(
             _yarn_get_mscale(self.scaling_factor) * attn_factor)
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base**(
@@ -359,6 +363,7 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
         original_max_position_embeddings: int,
         base: int,
         is_neox_style: bool,
+        dtype: torch.dtype,
         short_factor: List[float],
         long_factor: List[float],
         short_mscale: float = 1.1,
@@ -385,14 +390,14 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
 
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
-        short_cache = short_cache.to(torch.get_default_dtype())
+        short_cache = short_cache.to(dtype)
         self.register_buffer("short_cos_sin_cache",
                              short_cache,
                              persistent=False)
 
         long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                  long_factor, long_mscale)
-        long_cache = long_cache.to(torch.get_default_dtype())
+        long_cache = long_cache.to(dtype)
         self.register_buffer("long_cos_sin_cache",
                              long_cache,
                              persistent=False)
@@ -463,7 +468,10 @@ def get_rope(
     base: int,
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> RotaryEmbedding:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
     if rope_scaling is not None:
         # Transforms every value that is a list into a tuple for caching calls
         rope_scaling_tuple = {
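dtype is optional in get_rope(); as the added fallback shows, callers that do not pass one keep the old behaviour of using torch.get_default_dtype(). Both call styles, sketched with made-up arguments:

    rope_default = get_rope(64, 64, 2048, 10000)                      # falls back to torch.get_default_dtype()
    rope_pinned = get_rope(64, 64, 2048, 10000, dtype=torch.float16)  # cache pinned to the model dtype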
@@ -474,12 +482,12 @@ def get_rope(
     else:
         rope_scaling_args = None
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
-           rope_scaling_args)
+           rope_scaling_args, dtype)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
-                                     is_neox_style)
+                                     is_neox_style, dtype)
     else:
         scaling_type = rope_scaling["type"]
         if scaling_type != "su":
@@ -488,11 +496,11 @@ def get_rope(
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
-                                                      scaling_factor)
+                                                      scaling_factor, dtype)
         elif scaling_type == "dynamic":
             rotary_emb = DynamicNTKScalingRotaryEmbedding(
                 head_size, rotary_dim, max_position, base, is_neox_style,
-                scaling_factor)
+                scaling_factor, dtype)
         elif scaling_type == "yarn":
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
@@ -505,7 +513,7 @@ def get_rope(
             rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                     original_max_position,
                                                     base, is_neox_style,
-                                                    scaling_factor,
+                                                    scaling_factor, dtype,
                                                     **extra_kwargs)
         elif scaling_type == "su":
             short_factor = rope_scaling["short_factor"]
@@ -519,7 +527,8 @@ def get_rope(
             }
             rotary_emb = Phi3SuScaledRotaryEmbedding(
                 head_size, rotary_dim, max_position, original_max_position,
-                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
+                base, is_neox_style, dtype, short_factor, long_factor,
+                **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
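Every construction branch above now passes dtype, so the module stored under a given _ROPE_DICT key always matches the dtype recorded in that key, including for the scaled variants. A last illustrative check (values invented for the example):

    scaling = {"type": "linear", "factor": 2.0}
    a = get_rope(64, 64, 2048, 10000, True, scaling, torch.float16)
    b = get_rope(64, 64, 2048, 10000, True, scaling, torch.float32)
    assert a is not b  # same scaling config, different dtype -> different cached modules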