diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index fa85f72e..e6ee2b96 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -734,34 +734,50 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
         return inv_freq
 
 
-class ExtendedRotaryEmbedding(RotaryEmbedding):
+class Llama3RotaryEmbedding(RotaryEmbedding):
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        scaling_factor: float,
+        low_freq_factor: float,
+        high_freq_factor: float,
+        orig_max_position: int,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.low_freq_factor = low_freq_factor
+        self.high_freq_factor = high_freq_factor
+        self.orig_max_position = orig_max_position
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         inv_freqs = super()._compute_inv_freq(base)
-        return self.apply_scaling(inv_freqs)
+        low_freq_wavelen = self.orig_max_position / self.low_freq_factor
+        high_freq_wavelen = self.orig_max_position / self.high_freq_factor
 
-    def apply_scaling(self, freqs: torch.Tensor):
-        scale_factor = 8
-        low_freq_factor = 1
-        high_freq_factor = 4
-        old_context_len = 8192
-
-        low_freq_wavelen = old_context_len / low_freq_factor
-        high_freq_wavelen = old_context_len / high_freq_factor
-        new_freqs = []
-        for freq in freqs:
-            wavelen = 2 * math.pi / freq
-            if wavelen < high_freq_wavelen:
-                new_freqs.append(freq)
-            elif wavelen > low_freq_wavelen:
-                new_freqs.append(freq / scale_factor)
-            else:
-                assert low_freq_wavelen != high_freq_wavelen
-                smooth = (old_context_len / wavelen - low_freq_factor) / (
-                    high_freq_factor - low_freq_factor)
-                new_freqs.append((1 - smooth) * freq / scale_factor +
-                                 smooth * freq)
-        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+        wave_len = 2 * math.pi / inv_freqs
+        if self.low_freq_factor != self.high_freq_factor:
+            smooth = (self.orig_max_position / wave_len - self.low_freq_factor
+                      ) / (self.high_freq_factor - self.low_freq_factor)
+        else:
+            smooth = 0
+        new_freqs = torch.where(
+            wave_len < high_freq_wavelen,
+            inv_freqs,
+            torch.where(
+                wave_len > low_freq_wavelen,
+                inv_freqs / self.scaling_factor,
+                (1 - smooth) * inv_freqs / self.scaling_factor +
+                smooth * inv_freqs,
+            ),
+        )
+        return new_freqs
 
 
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
@@ -794,6 +810,7 @@ def get_rope(
                 rope_scaling_args, dtype)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
+
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
@@ -802,12 +819,19 @@ def get_rope(
             "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
-        if scaling_type not in {"su", "longrope", "llama3"}:
+        if scaling_type not in {"su", "longrope"}:
             scaling_factor = rope_scaling["factor"]
         if scaling_type == "llama3":
-            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
-                                                 max_position, base,
-                                                 is_neox_style, dtype)
+            low_freq_factor = rope_scaling["low_freq_factor"]
+            high_freq_factor = rope_scaling["high_freq_factor"]
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
+                                               max_position, base,
+                                               is_neox_style, dtype,
+                                               scaling_factor, low_freq_factor,
+                                               high_freq_factor,
+                                               original_max_position)
         elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
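
Not part of the diff: a minimal usage sketch exercising the new "llama3" branch of get_rope. The keyword names mirror the parameters visible in the patch; the rope_scaling values and base are illustrative placeholders (roughly Llama 3.1-style settings), not taken from this change.

    # Illustrative only: keys match those read from rope_scaling above;
    # the numeric values are assumed for the example, not dictated by the patch.
    from vllm.model_executor.layers.rotary_embedding import (
        Llama3RotaryEmbedding, get_rope)

    rope = get_rope(
        head_size=128,
        rotary_dim=128,
        max_position=131072,
        base=500000,
        is_neox_style=True,
        rope_scaling={
            "rope_type": "llama3",
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192,
        },
    )
    # get_rope dispatches to the new class, whose inverse frequencies are
    # rescaled per the wavelength thresholds computed in _compute_inv_freq.
    assert isinstance(rope, Llama3RotaryEmbedding)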