[Misc] Refactor Llama3 RoPE initialization (#7637)

Woosuk Kwon 2024-08-18 17:18:12 -07:00 committed by GitHub
parent 40e1360bb6
commit 200a2ffa6b


@@ -734,34 +734,50 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
         return inv_freq
 
 
-class ExtendedRotaryEmbedding(RotaryEmbedding):
+class Llama3RotaryEmbedding(RotaryEmbedding):
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        scaling_factor: float,
+        low_freq_factor: float,
+        high_freq_factor: float,
+        orig_max_position: int,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.low_freq_factor = low_freq_factor
+        self.high_freq_factor = high_freq_factor
+        self.orig_max_position = orig_max_position
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         inv_freqs = super()._compute_inv_freq(base)
-        return self.apply_scaling(inv_freqs)
-
-    def apply_scaling(self, freqs: torch.Tensor):
-        scale_factor = 8
-        low_freq_factor = 1
-        high_freq_factor = 4
-        old_context_len = 8192
-
-        low_freq_wavelen = old_context_len / low_freq_factor
-        high_freq_wavelen = old_context_len / high_freq_factor
-        new_freqs = []
-        for freq in freqs:
-            wavelen = 2 * math.pi / freq
-            if wavelen < high_freq_wavelen:
-                new_freqs.append(freq)
-            elif wavelen > low_freq_wavelen:
-                new_freqs.append(freq / scale_factor)
-            else:
-                assert low_freq_wavelen != high_freq_wavelen
-                smooth = (old_context_len / wavelen - low_freq_factor) / (
-                    high_freq_factor - low_freq_factor)
-                new_freqs.append((1 - smooth) * freq / scale_factor +
-                                 smooth * freq)
-        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+        low_freq_wavelen = self.orig_max_position / self.low_freq_factor
+        high_freq_wavelen = self.orig_max_position / self.high_freq_factor
+
+        wave_len = 2 * math.pi / inv_freqs
+        if self.low_freq_factor != self.high_freq_factor:
+            smooth = (self.orig_max_position / wave_len - self.low_freq_factor
+                      ) / (self.high_freq_factor - self.low_freq_factor)
+        else:
+            smooth = 0
+        new_freqs = torch.where(
+            wave_len < high_freq_wavelen,
+            inv_freqs,
+            torch.where(
+                wave_len > low_freq_wavelen,
+                inv_freqs / self.scaling_factor,
+                (1 - smooth) * inv_freqs / self.scaling_factor +
+                smooth * inv_freqs,
+            ),
+        )
+        return new_freqs
 
 
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
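
For reference, a minimal sketch of constructing the refactored class directly, using the constants that the old ExtendedRotaryEmbedding.apply_scaling hardcoded (scale factor 8, low_freq_factor 1, high_freq_factor 4, original context length 8192). The head_size, rotary_dim, max_position_embeddings, and base values below are illustrative only, and the import path assumes the class is exported from vllm.model_executor.layers.rotary_embedding:

import torch
from vllm.model_executor.layers.rotary_embedding import Llama3RotaryEmbedding

# Illustrative shapes; only the last four arguments mirror the constants that
# the old apply_scaling() hardcoded.
rope = Llama3RotaryEmbedding(
    head_size=128,
    rotary_dim=128,
    max_position_embeddings=131072,
    base=500000,
    is_neox_style=True,
    dtype=torch.float16,
    scaling_factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    orig_max_position=8192,
)

Note that the new _compute_inv_freq also replaces the per-frequency Python loop with a single vectorized torch.where over the inv_freqs tensor, and when low_freq_factor equals high_freq_factor it sets smooth to 0 instead of asserting.
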
@@ -794,6 +810,7 @@ def get_rope(
            rope_scaling_args, dtype)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
+
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
@@ -802,12 +819,19 @@ def get_rope(
             "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
-        if scaling_type not in {"su", "longrope", "llama3"}:
+        if scaling_type not in {"su", "longrope"}:
             scaling_factor = rope_scaling["factor"]
         if scaling_type == "llama3":
-            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
-                                                 max_position, base,
-                                                 is_neox_style, dtype)
+            low_freq_factor = rope_scaling["low_freq_factor"]
+            high_freq_factor = rope_scaling["high_freq_factor"]
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
+                                               max_position, base,
+                                               is_neox_style, dtype,
+                                               scaling_factor, low_freq_factor,
+                                               high_freq_factor,
+                                               original_max_position)
         elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
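
A hedged example of the HF-style rope_scaling dict that routes get_rope through this new branch; the keys are the ones this hunk reads, while the numeric values and the head_size/rotary_dim/max_position/base arguments are illustrative rather than taken from this commit:

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=131072,
    base=500000,
    is_neox_style=True,
    rope_scaling={
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
    dtype=torch.float16,
)

Because "llama3" is no longer excluded from the scaling_factor lookup, the "factor" key must be present in the dict; it is forwarded to Llama3RotaryEmbedding as scaling_factor together with the three new keys.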