[Misc] Remove Gemma RoPE (#7638)
commit df845b2b46
parent 1a36287b89
vllm/model_executor/layers/rotary_embedding.py

@@ -93,11 +93,6 @@ class RotaryEmbedding(CustomOp):
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
-        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
-        # However, we use `torch.arange(..., dtype=torch.float)` instead to
-        # avoid numerical issues with large base values (e.g., 10000000).
-        # This may cause a slight numerical difference between the HF
-        # implementation and ours.
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
         # create the cache on GPU for faster initialization. This may cause
@@ -724,16 +719,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         return query, key
 
 
-class GemmaRotaryEmbedding(RotaryEmbedding):
-
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
-            self.rotary_dim))
-        return inv_freq
-
-
 class Llama3RotaryEmbedding(RotaryEmbedding):
 
     def __init__(
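Note (not part of the diff): the deleted GemmaRotaryEmbedding only overrode _compute_inv_freq, building the inverse frequencies from an int64 arange cast to float instead of a float arange. Since the exponents are small even integers that are exactly representable in float32, the two forms agree bit for bit. A minimal sketch of that check, assuming the base RotaryEmbedding builds inv_freq from torch.arange(..., dtype=torch.float) as the removed comments describe:

import torch

rotary_dim, base = 256, 10000.0  # placeholder values, not taken from the diff

# Removed Gemma-specific variant: integer arange, then cast to float.
gemma_inv_freq = 1.0 / (base**(
    torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim))

# Generic variant assumed for the base RotaryEmbedding: float arange directly.
generic_inv_freq = 1.0 / (base**(
    torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))

# Small even integers are exact in float32, so the two paths agree exactly.
assert torch.equal(gemma_inv_freq, generic_inv_freq)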
vllm/model_executor/models/gemma.py

@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -148,14 +148,12 @@ class GemmaAttention(nn.Module):
             quant_config=quant_config,
         )
 
-        # TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
-            max_position_embeddings=max_position_embeddings,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
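Note (not part of the diff): with the dedicated class gone, the attention module obtains its rotary embedding through the shared get_rope factory, passing max_position= instead of the constructor's max_position_embeddings= keyword and dropping the explicit dtype. A minimal sketch of the new call pattern; the head size, context length, and theta below are placeholders, not values from the diff:

from vllm.model_executor.layers.rotary_embedding import get_rope

# Hypothetical per-head configuration (placeholders).
head_dim = 256
max_position_embeddings = 8192
rope_theta = 10000.0

# Mirrors the call site above: positional head size, keyword rotary_dim,
# max_position, base, and is_neox_style.
rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    is_neox_style=True,
)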
vllm/model_executor/models/gemma2.py

@@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -130,14 +130,12 @@ class Gemma2Attention(nn.Module):
             bias=config.attention_bias,
             quant_config=quant_config,
         )
-        # TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = get_rope(
             self.head_dim,
-            self.head_dim,
-            max_position_embeddings,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
 
         # FIXME(woosuk): While Gemma 2 uses sliding window attention for every
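Note (not part of the diff): both call sites drop the explicit dtype=torch.get_default_dtype() argument. The sketch below assumes that get_rope falls back to the process-wide default dtype when no dtype is supplied, so behavior is unchanged; the cos_sin_cache attribute name is likewise an assumption, not something shown in this diff.

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

# Assumption: with no dtype argument, get_rope uses torch.get_default_dtype(),
# matching the old explicit dtype=torch.get_default_dtype() behavior.
torch.set_default_dtype(torch.bfloat16)

rope = get_rope(
    128,                      # placeholder head size
    rotary_dim=128,
    max_position=4096,
    base=10000.0,
    is_neox_style=True,
)

# Assumption: the layer exposes its cached cos/sin table as `cos_sin_cache`.
print(rope.cos_sin_cache.dtype)  # expected torch.bfloat16 under the assumptions above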