From 27b50f1fe6e325f73b405263c3ac1fc668531118 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Fri, 14 Mar 2025 14:47:49 +0800
Subject: [PATCH] [Bugfix][Kernel][CPU] Fix num_tokens in CPU rotary embedding kernel (#14667)

Signed-off-by: Thien Tran
---
 csrc/cpu/pos_encoding.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp
index 96bce7dd..8a59e884 100644
--- a/csrc/cpu/pos_encoding.cpp
+++ b/csrc/cpu/pos_encoding.cpp
@@ -170,7 +170,7 @@ void rotary_embedding_gptj_impl(
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                       torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox) {
-  int num_tokens = query.numel() / query.size(-1);
+  int num_tokens = positions.numel();
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-1) / head_size;
   int num_kv_heads = key.size(-1) / head_size;