[Attention] Remove slow setattr in MLA (#14769)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-03-13 17:31:14 -04:00 committed by GitHub
parent 02fcaa3d0a
commit d47807ba08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -161,8 +161,13 @@ class RotaryEmbedding(CustomOp):
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
# is expensive, so avoid calling it if possible
if self.cos_sin_cache.device != query.device or \
self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None: