vllm/csrc/pos_encoding.cpp
2023-04-02 00:30:17 -07:00

15 lines
340 B
C++

#include <torch/extension.h>
// Applies GPT-NeoX style rotary embedding to `query` and `key`.
//
// This is a declaration only; the definition is not visible in this file and
// has to be supplied by another translation unit of the extension.
//
// Every tensor is taken by non-const lvalue reference and nothing is returned.
// NOTE(review): presumably `query` and `key` are rewritten in place -- confirm
// against the definition.
// NOTE(review): the expected shapes/dtypes of `positions` and `cos_sin_cache`
// cannot be determined from here -- confirm against the definition.
void rotary_embedding_neox(torch::Tensor& positions, torch::Tensor& query,
                           torch::Tensor& key, torch::Tensor& cos_sin_cache);
// Python module definition for this extension. The module name comes from the
// TORCH_EXTENSION_NAME macro, which is not defined in this file (expected to
// be provided by the extension build).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Expose the C++ function to Python under the same name, with a one-line
  // docstring.
  m.def("rotary_embedding_neox", &rotary_embedding_neox,
        "Apply GPT-NeoX style rotary embedding to query and key");
}