#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "dispatch_utils.h"
#include "reduction_utils.cuh"

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
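// Computes RMS normalization over the last dimension:
//   out[t, i] = input[t, i] * rsqrt(mean_j(input[t, j]^2) + epsilon) * weight[i]
// One thread block processes one token (one row of hidden_size elements).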
template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,             // [..., hidden_size]
  const scalar_t* __restrict__ input,     // [..., hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;
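
  // Pass 1: each thread accumulates a partial sum of squares over a strided
  // slice of this token's row. Accumulation happens in float regardless of
  // scalar_t for accuracy; "variance" holds the raw sum of squares here.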
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
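  // Sum the per-thread partials across the block (blockReduceSum comes from
  // reduction_utils.cuh); thread 0 then publishes the reciprocal RMS.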
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();
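
  // Pass 2: s_variance is now visible to every thread; normalize each element
  // and apply the learned per-channel scale.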
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

} // namespace vllm

void rms_norm(
  torch::Tensor& out,      // [..., hidden_size]
  torch::Tensor& input,    // [..., hidden_size]
  torch::Tensor& weight,   // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
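
  // Launch one thread block per token; block size is capped at 1024 threads
  // (the CUDA per-block limit), so each thread may handle several elements.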
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
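  // VLLM_DISPATCH_FLOATING_TYPES (from dispatch_utils.h) instantiates the
  // lambda with scalar_t bound to the tensor's dtype, e.g. float/half/bfloat16.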
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
      vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
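
// Example host-side call (a minimal sketch; the shapes, dtype, and eps value
// below are illustrative assumptions, not fixed by this file):
//
//   torch::Tensor x   = torch::randn({num_tokens, hidden_size},
//                                    torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   torch::Tensor w   = torch::ones({hidden_size}, x.options());
//   torch::Tensor out = torch::empty_like(x);
//   rms_norm(out, x, w, 1e-6f);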