#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <algorithm>
#include <cstdint>

#include "dispatch_utils.h"

namespace vllm {

// Applies GPT-NeoX-style rotary embedding to one (x, y) element pair of a
// single head, in place. NeoX style pairs element `rot_offset` of the first
// half of the rotary dims with element `rot_offset` of the second half:
//
//   x' = x * cos - y * sin
//   y' = y * cos + x * sin
//
// head       : pointer to the first element of the head being rotated.
// cache_ptr  : cos/sin row for this token's position; cos values occupy
//              [0, embed_dim) and sin values occupy [embed_dim, 2 * embed_dim).
// rot_offset : pair index in [0, embed_dim).
// embed_dim  : rot_dim / 2.
template<typename scalar_t>
__device__ __forceinline__ void apply_rotary_embedding_neox(
  scalar_t* __restrict__ head,
  const scalar_t* __restrict__ cache_ptr,
  const int rot_offset,
  const int embed_dim) {
  const int x_index = rot_offset;
  const int y_index = embed_dim + rot_offset;

  // Named *_val so they do not shadow the cos/sin math functions.
  const scalar_t cos_val = __ldg(cache_ptr + x_index);
  const scalar_t sin_val = __ldg(cache_ptr + y_index);

  // Each (x_index, y_index) pair is owned by exactly one thread, so the
  // read-then-write below is race-free.
  const scalar_t x = head[x_index];
  const scalar_t y = head[y_index];
  head[x_index] = x * cos_val - y * sin_val;
  head[y_index] = y * cos_val + x * sin_val;
}

// Rotates the query and key vectors of one token in place.
//
// Launch layout: one thread block per token (blockIdx.x is the token index).
// Threads of the block loop over the (head, pair) work items with step
// blockDim.x, so any blockDim.x >= 1 is correct. Uses no shared memory.
//
// Preconditions (validated on the host by rotary_embedding_neox, except the
// position range): rot_dim is even and <= head_size; the last dimension of
// query/key is contiguous; cos_sin_cache rows are contiguous. Every
// positions[t] must be < max_position -- this is NOT checked, and an
// out-of-range position reads past the end of cos_sin_cache.
template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
  const int64_t* __restrict__ positions,        // [num_tokens]
  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int query_stride,
  const int key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  const int64_t token_idx = blockIdx.x;
  const int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
  const int embed_dim = rot_dim / 2;

  // Row offsets in 64-bit: token_idx * stride can exceed INT_MAX for large
  // batches, and 32-bit arithmetic would silently wrap.
  scalar_t* query_row = query + token_idx * query_stride;
  scalar_t* key_row = key + token_idx * key_stride;

  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding_neox(
      query_row + head_idx * head_size, cache_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding_neox(
      key_row + head_idx * head_size, cache_ptr, rot_offset, embed_dim);
  }
}

} // namespace vllm

// Applies GPT-NeoX-style rotary position embedding to `query` and `key` in
// place, on the current CUDA stream of query's device.
//
// positions     : int64 token positions, one per token.
// query / key   : 2-D tensors whose rows may be strided (e.g. slices of a
//                 fused QKV tensor) but whose last dimension must be
//                 contiguous.
// head_size     : per-head width; must divide query.size(1) and key.size(1).
// cos_sin_cache : contiguous [max_position, rot_dim] table holding the cos
//                 values in the first rot_dim/2 columns and the sin values
//                 in the last rot_dim/2 columns. Must share query's dtype.
//
// Only the first rot_dim elements of each head are rotated (partial rotary
// is supported when rot_dim < head_size).
// Throws c10::Error (via TORCH_CHECK) on shape/stride violations or on a
// failed kernel launch.
void rotary_embedding_neox(
  torch::Tensor& positions,         // [num_tokens]
  torch::Tensor& query,             // [num_tokens, num_heads * head_size]
  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache)     // [max_position, rot_dim]
{
  TORCH_CHECK(query.dim() == 2 && key.dim() == 2,
              "rotary_embedding_neox: query and key must be 2-D");
  TORCH_CHECK(cos_sin_cache.dim() == 2,
              "rotary_embedding_neox: cos_sin_cache must be 2-D");
  // Checked before any division by head_size below.
  TORCH_CHECK(head_size > 0,
              "rotary_embedding_neox: head_size must be positive, got ", head_size);

  const int num_tokens = static_cast<int>(query.size(0));
  const int rot_dim = static_cast<int>(cos_sin_cache.size(1));

  TORCH_CHECK(rot_dim % 2 == 0 && rot_dim <= head_size,
              "rotary_embedding_neox: rot_dim (", rot_dim,
              ") must be even and <= head_size (", head_size, ")");
  TORCH_CHECK(query.size(1) % head_size == 0 && key.size(1) % head_size == 0,
              "rotary_embedding_neox: query/key width must be a multiple of head_size");
  TORCH_CHECK(key.size(0) == num_tokens && positions.numel() == num_tokens,
              "rotary_embedding_neox: query, key and positions must agree on num_tokens");
  // The kernel indexes elements within a row as contiguous memory.
  TORCH_CHECK(query.stride(1) == 1 && key.stride(1) == 1,
              "rotary_embedding_neox: last dimension of query/key must be contiguous");
  // The kernel addresses cache rows as pos * rot_dim.
  TORCH_CHECK(cos_sin_cache.is_contiguous(),
              "rotary_embedding_neox: cos_sin_cache must be contiguous");

  const int num_heads = static_cast<int>(query.size(1) / head_size);
  const int num_kv_heads = static_cast<int>(key.size(1) / head_size);
  const int query_stride = static_cast<int>(query.stride(0));
  const int key_stride = static_cast<int>(key.stride(0));

  // Size the block by the larger head count so neither query nor key work
  // is under-provisioned; the kernel's strided loops stay correct either way.
  const int max_pairs = std::max(num_heads, num_kv_heads) * (rot_dim / 2);
  if (num_tokens == 0 || max_pairs == 0) {
    // Nothing to rotate; a zero-sized grid or block is an invalid launch.
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(max_pairs, 512));

  // Launch on query's device, not whatever device happens to be current.
  const c10::cuda::OptionalCUDAGuard device_guard(at::device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding_neox",
    [&] {
      vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
        positions.data_ptr<int64_t>(),
        query.data_ptr<scalar_t>(),
        key.data_ptr<scalar_t>(),
        cos_sin_cache.data_ptr<scalar_t>(),
        rot_dim,
        query_stride,
        key_stride,
        num_heads,
        num_kv_heads,
        head_size);
    });
  // Surface launch-configuration errors here instead of at a later sync.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}