[Misc] Refactor Llama3 RoPE initialization (#7637)
commit 200a2ffa6b (parent 40e1360bb6)
@@ -734,34 +734,50 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
         return inv_freq
 
 
-class ExtendedRotaryEmbedding(RotaryEmbedding):
+class Llama3RotaryEmbedding(RotaryEmbedding):
 
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        scaling_factor: float,
+        low_freq_factor: float,
+        high_freq_factor: float,
+        orig_max_position: int,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.low_freq_factor = low_freq_factor
+        self.high_freq_factor = high_freq_factor
+        self.orig_max_position = orig_max_position
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         inv_freqs = super()._compute_inv_freq(base)
-        return self.apply_scaling(inv_freqs)
-
-    def apply_scaling(self, freqs: torch.Tensor):
-        scale_factor = 8
-        low_freq_factor = 1
-        high_freq_factor = 4
-        old_context_len = 8192
-
-        low_freq_wavelen = old_context_len / low_freq_factor
-        high_freq_wavelen = old_context_len / high_freq_factor
-        new_freqs = []
-        for freq in freqs:
-            wavelen = 2 * math.pi / freq
-            if wavelen < high_freq_wavelen:
-                new_freqs.append(freq)
-            elif wavelen > low_freq_wavelen:
-                new_freqs.append(freq / scale_factor)
-            else:
-                assert low_freq_wavelen != high_freq_wavelen
-                smooth = (old_context_len / wavelen - low_freq_factor) / (
-                    high_freq_factor - low_freq_factor)
-                new_freqs.append((1 - smooth) * freq / scale_factor +
-                                 smooth * freq)
-        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+        low_freq_wavelen = self.orig_max_position / self.low_freq_factor
+        high_freq_wavelen = self.orig_max_position / self.high_freq_factor
+
+        wave_len = 2 * math.pi / inv_freqs
+        if self.low_freq_factor != self.high_freq_factor:
+            smooth = (self.orig_max_position / wave_len - self.low_freq_factor
+                      ) / (self.high_freq_factor - self.low_freq_factor)
+        else:
+            smooth = 0
+        new_freqs = torch.where(
+            wave_len < high_freq_wavelen,
+            inv_freqs,
+            torch.where(
+                wave_len > low_freq_wavelen,
+                inv_freqs / self.scaling_factor,
+                (1 - smooth) * inv_freqs / self.scaling_factor +
+                smooth * inv_freqs,
+            ),
+        )
+        return new_freqs
 
 
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
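Note on the change: the removed ExtendedRotaryEmbedding hardcoded the Llama 3.1 scaling constants (scale factor 8, low/high frequency factors 1 and 4, original context 8192) and scaled each frequency in a Python loop; Llama3RotaryEmbedding receives the same values as constructor arguments and applies the scaling in one vectorized torch.where pass. A minimal standalone sketch, not part of the diff, showing the two forms agree (base and rotary_dim are illustrative assumptions, not values from this commit):

import math
import torch

# Constants the old code hardcoded; the new class gets them via config.
scaling_factor, low_freq_factor, high_freq_factor = 8.0, 1.0, 4.0
orig_max_position = 8192

# Illustrative RoPE setup (assumed values).
base, rotary_dim = 500000, 128
inv_freqs = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) /
                          rotary_dim))

low_freq_wavelen = orig_max_position / low_freq_factor
high_freq_wavelen = orig_max_position / high_freq_factor

# Removed form: per-frequency Python loop.
looped = []
for freq in inv_freqs:
    wavelen = 2 * math.pi / freq
    if wavelen < high_freq_wavelen:
        looped.append(freq)
    elif wavelen > low_freq_wavelen:
        looped.append(freq / scaling_factor)
    else:
        smooth = (orig_max_position / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor)
        looped.append((1 - smooth) * freq / scaling_factor + smooth * freq)
looped = torch.stack(looped)

# New form: one vectorized pass over the whole frequency tensor.
wave_len = 2 * math.pi / inv_freqs
smooth = (orig_max_position / wave_len - low_freq_factor) / (
    high_freq_factor - low_freq_factor)
vectorized = torch.where(
    wave_len < high_freq_wavelen, inv_freqs,
    torch.where(wave_len > low_freq_wavelen, inv_freqs / scaling_factor,
                (1 - smooth) * inv_freqs / scaling_factor +
                smooth * inv_freqs))

assert torch.allclose(looped, vectorized)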
@@ -794,6 +810,7 @@ def get_rope(
            rope_scaling_args, dtype)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
+
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
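Context worth noting: get_rope memoizes every embedding it builds in the module-level _ROPE_DICT, keyed by the full argument tuple (with rope_scaling flattened into the hashable rope_scaling_args), so the new llama3 parameters participate in the cache key through rope_scaling. A rough sketch of the effect, with hypothetical argument values:

# Hypothetical values; the point is that a second call with identical
# arguments returns the cached module instead of constructing a new one.
emb_a = get_rope(head_size=128, rotary_dim=128, max_position=8192,
                 base=10000, is_neox_style=True)
emb_b = get_rope(head_size=128, rotary_dim=128, max_position=8192,
                 base=10000, is_neox_style=True)
assert emb_a is emb_b  # served from _ROPE_DICT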
@@ -802,12 +819,19 @@ def get_rope(
             "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
-        if scaling_type not in {"su", "longrope", "llama3"}:
+        if scaling_type not in {"su", "longrope"}:
             scaling_factor = rope_scaling["factor"]
         if scaling_type == "llama3":
-            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
-                                                 max_position, base,
-                                                 is_neox_style, dtype)
+            low_freq_factor = rope_scaling["low_freq_factor"]
+            high_freq_factor = rope_scaling["high_freq_factor"]
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
+                                               max_position, base,
+                                               is_neox_style, dtype,
+                                               scaling_factor, low_freq_factor,
+                                               high_freq_factor,
+                                               original_max_position)
         elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
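After this refactor the scaling parameters travel from the checkpoint's rope_scaling config into the embedding instead of being baked in. A sketch of the llama3 path through get_rope; the dict values mirror the constants the old code hardcoded, while head_size, rotary_dim, base, and max_position are illustrative assumptions:

rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
# get_rope reads "factor" into scaling_factor (now that "llama3" is no
# longer excluded), pulls the three llama3-specific keys, and passes all
# of them to Llama3RotaryEmbedding.
rotary_emb = get_rope(head_size=128, rotary_dim=128, max_position=131072,
                      base=500000, is_neox_style=True,
                      rope_scaling=rope_scaling, dtype=torch.bfloat16)
assert isinstance(rotary_emb, Llama3RotaryEmbedding)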