diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index 18c8e351..076730cd 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -1,4 +1,4 @@
-from itertools import accumulate
+from itertools import accumulate, product
 from typing import List, Optional
 
 import pytest
@@ -207,3 +207,45 @@ def test_batched_rotary_embedding_multi_lora(
                           ref_key,
                           atol=get_default_atol(out_key),
                           rtol=get_default_rtol(out_key))
+
+
+@torch.inference_mode()
+def test_rope_module_cache():
+    MAX_POSITIONS = [123, 1234]
+    BASES = [10000, 1000000]
+    ROPE_SCALINGS = [
+        None, {
+            "type": "linear",
+            "factor": (1, )
+        }, {
+            "type": "dynamic",
+            "factor": 1
+        }
+    ]
+    settings = [
+        HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
+        ROPE_SCALINGS, DTYPES
+    ]
+    rope_setting_id_map = {}
+    for setting in product(*settings):
+        head_size, rotary_dim, max_position, base, \
+            is_neox_style, rope_scaling, dtype = setting
+        if rotary_dim is None:
+            rotary_dim = head_size
+        rope = get_rope(head_size, rotary_dim, max_position, base,
+                        is_neox_style, rope_scaling, dtype)
+        # Different settings must not share the same rope module.
+        assert id(rope) not in rope_setting_id_map.values()
+        assert all(x.dtype == dtype for x in rope.buffers())
+        assert all(x.dtype == dtype for x in rope.parameters())
+        rope_setting_id_map[str(setting)] = id(rope)
+
+    for setting in product(*settings):
+        head_size, rotary_dim, max_position, base, \
+            is_neox_style, rope_scaling, dtype = setting
+        if rotary_dim is None:
+            rotary_dim = head_size
+        rope = get_rope(head_size, rotary_dim, max_position, base,
+                        is_neox_style, rope_scaling, dtype)
+        # Check that the cache takes effect.
+        assert id(rope) == rope_setting_id_map[str(setting)]
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index f41e0f30..4758ca96 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -53,6 +53,7 @@ class RotaryEmbedding(nn.Module):
         max_position_embeddings: int,
         base: int,
         is_neox_style: bool,
+        dtype: torch.dtype,
     ) -> None:
         super().__init__()
         self.head_size = head_size
@@ -62,7 +63,7 @@ class RotaryEmbedding(nn.Module):
         self.is_neox_style = is_neox_style
 
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(torch.get_default_dtype())
+        cache = cache.to(dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
@@ -178,12 +179,13 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factors: Union[List[float], float],
+        dtype: torch.dtype,
     ) -> None:
         if isinstance(scaling_factors, float):
             scaling_factors = [scaling_factors]
         self.scaling_factors = scaling_factors
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.base)
@@ -219,10 +221,11 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factor: float,
+        dtype: torch.dtype,
     ) -> None:
         self.scaling_factor = scaling_factor
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         # NOTE(woosuk): self.max_position_embeddings is the original
@@ -299,6 +302,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factor: float,
+        dtype: torch.dtype,
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
@@ -314,7 +318,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         self.mscale = float(
             _yarn_get_mscale(self.scaling_factor) * attn_factor)
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base**(
@@ -359,6 +363,7 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
         original_max_position_embeddings: int,
         base: int,
         is_neox_style: bool,
+        dtype: torch.dtype,
         short_factor: List[float],
         long_factor: List[float],
         short_mscale: float = 1.1,
@@ -385,14 +390,14 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
 
         short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale)
-        short_cache = short_cache.to(torch.get_default_dtype())
+        short_cache = short_cache.to(dtype)
         self.register_buffer("short_cos_sin_cache",
                              short_cache,
                              persistent=False)
 
         long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                  long_factor, long_mscale)
-        long_cache = long_cache.to(torch.get_default_dtype())
+        long_cache = long_cache.to(dtype)
         self.register_buffer("long_cos_sin_cache",
                              long_cache,
                              persistent=False)
@@ -463,7 +468,10 @@ def get_rope(
     base: int,
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> RotaryEmbedding:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
     if rope_scaling is not None:
         # Transforms every value that is a list into a tuple for caching calls
         rope_scaling_tuple = {
@@ -474,12 +482,12 @@ def get_rope(
     else:
         rope_scaling_args = None
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
-           rope_scaling_args)
+           rope_scaling_args, dtype)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
-                                     is_neox_style)
+                                     is_neox_style, dtype)
     else:
         scaling_type = rope_scaling["type"]
         if scaling_type != "su":
@@ -488,11 +496,11 @@ def get_rope(
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
-                                                      scaling_factor)
+                                                      scaling_factor, dtype)
         elif scaling_type == "dynamic":
             rotary_emb = DynamicNTKScalingRotaryEmbedding(
                 head_size, rotary_dim, max_position, base, is_neox_style,
-                scaling_factor)
+                scaling_factor, dtype)
         elif scaling_type == "yarn":
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
@@ -505,7 +513,7 @@ def get_rope(
             rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                     original_max_position,
                                                     base, is_neox_style,
-                                                    scaling_factor,
+                                                    scaling_factor, dtype,
                                                     **extra_kwargs)
         elif scaling_type == "su":
             short_factor = rope_scaling["short_factor"]
@@ -519,7 +527,8 @@ def get_rope(
             }
             rotary_emb = Phi3SuScaledRotaryEmbedding(
                 head_size, rotary_dim, max_position, original_max_position,
-                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
+                base, is_neox_style, dtype, short_factor, long_factor,
+                **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
         _ROPE_DICT[key] = rotary_emb
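
Minimal usage sketch (not part of the patch above; it assumes only the get_rope signature and the dtype-aware _ROPE_DICT cache key introduced in this diff, with example argument values chosen for illustration):

    import torch

    from vllm.model_executor.layers.rotary_embedding import get_rope

    # Pin the cos/sin cache to half precision instead of relying on
    # torch.get_default_dtype().
    rope_fp16 = get_rope(head_size=128,
                         rotary_dim=128,
                         max_position=4096,
                         base=10000,
                         is_neox_style=True,
                         rope_scaling=None,
                         dtype=torch.float16)
    assert rope_fp16.cos_sin_cache.dtype == torch.float16

    # A different dtype is a different cache key, so a new module is built.
    rope_bf16 = get_rope(128, 128, 4096, 10000, True, None, torch.bfloat16)
    assert rope_fp16 is not rope_bf16

    # Identical settings hit the module cache and return the same object.
    assert get_rope(128, 128, 4096, 10000, True, None,
                    torch.float16) is rope_fp16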