[Bugfix][Kernel][CPU] Fix num_tokens in CPU rotary embedding kernel (#14667)

Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Author: Thien Tran, 2025-03-14 14:47:49 +08:00, committed by GitHub
Commit: 27b50f1fe6 (parent: 9532c49836)


@@ -170,7 +170,7 @@ void rotary_embedding_gptj_impl(
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                       torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox) {
-  int num_tokens = query.numel() / query.size(-1);
+  int num_tokens = positions.numel();
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-1) / head_size;
   int num_kv_heads = key.size(-1) / head_size;
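
The old expression `query.numel() / query.size(-1)` equals the token count only when `query` is a flat 2D tensor of shape `[num_tokens, num_heads * head_size]`; if `query` arrives as a 3D tensor `[num_tokens, num_heads, head_size]`, it yields `num_tokens * num_heads` instead. `positions` always holds exactly one entry per token, so `positions.numel()` is correct for either layout. Below is a minimal libtorch sketch of the arithmetic; the shapes (8 tokens, 4 heads, head size 64) are hypothetical and chosen only to illustrate the over-count:

#include <torch/torch.h>
#include <iostream>

int main() {
  // Hypothetical example: 8 tokens, 4 heads, head_size 64.
  auto positions = torch::zeros({8}, torch::kLong);
  auto query_2d = torch::zeros({8, 4 * 64});  // [num_tokens, num_heads * head_size]
  auto query_3d = torch::zeros({8, 4, 64});   // [num_tokens, num_heads, head_size]

  // Old formula: correct only for the flat 2D layout.
  std::cout << query_2d.numel() / query_2d.size(-1) << "\n";  // 8
  std::cout << query_3d.numel() / query_3d.size(-1) << "\n";  // 32 -- over-counts by num_heads

  // Fixed formula: one position entry per token, independent of query layout.
  std::cout << positions.numel() << "\n";  // 8
  return 0;
}

With the fix, the CPU kernel derives `num_tokens` from `positions` rather than from the query layout, so both 2D and 3D query tensors are handled consistently.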