From d47807ba0806c5bbd8fd08c19013c327b34dcac5 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Thu, 13 Mar 2025 17:31:14 -0400
Subject: [PATCH] [Attention] Remove slow setattr in MLA (#14769)

Signed-off-by: Lucas Wilkinson
---
 vllm/model_executor/layers/rotary_embedding.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index d4b8cf25..fd27775b 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -161,8 +161,13 @@ class RotaryEmbedding(CustomOp):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from vllm import _custom_ops as ops
 
-        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
-                                                   dtype=query.dtype)
+        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
+        # is expensive, so avoid calling it if possible
+        if self.cos_sin_cache.device != query.device or \
+                self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                       dtype=query.dtype)
+
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None:
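
The rationale behind the guard can be checked with a standalone micro-benchmark. The sketch below is illustrative and not part of the patch: the Rope module, buffer shape, and iteration counts are made-up stand-ins. It contrasts an unconditional buffer reassignment, which routes through nn.Module.__setattr__ on every call, with the guarded form from the hunk above, where two cheap device/dtype comparisons skip the assignment when the cache already matches.

# Illustrative micro-benchmark (not part of the patch): the Rope module
# name, buffer shape, and iteration counts below are made up.
import timeit

import torch
import torch.nn as nn


class Rope(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        # register_buffer() stores the tensor in self._buffers, so every
        # later `self.cos_sin_cache = ...` goes through the slow
        # nn.Module.__setattr__ path (several dict lookups and
        # isinstance checks per assignment).
        self.register_buffer("cos_sin_cache", torch.randn(4096, 128))


rope = Rope()
cache = rope.cos_sin_cache


def always_setattr() -> None:
    # Unconditional reassignment: .to() is a no-op here (same device
    # and dtype), but __setattr__ still runs on every call.
    rope.cos_sin_cache = cache.to(cache.device, dtype=cache.dtype)


def guarded_setattr() -> None:
    # Guarded form, as in the hunk above: two cheap comparisons skip
    # the assignment entirely when nothing would change.
    if rope.cos_sin_cache.device != cache.device or \
            rope.cos_sin_cache.dtype != cache.dtype:
        rope.cos_sin_cache = cache.to(cache.device, dtype=cache.dtype)


print("always :", timeit.timeit(always_setattr, number=100_000))
print("guarded:", timeit.timeit(guarded_setattr, number=100_000))

On the hot path of RotaryEmbedding's forward, the cache is almost always already on the right device and dtype, so the guarded branch is essentially never taken and the per-call cost drops to two attribute comparisons.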