#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "reduction_utils.cuh"

namespace cacheflow {

// TODO(woosuk): Further optimize this kernel.
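// RMSNorm kernel: one thread block per token (one row of `input`). Each block
// accumulates the sum of squares over the hidden dimension, reduces it across
// the block, and then scales the row by rsqrt(mean(x^2) + epsilon) and the
// per-channel `weight`.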
template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,           // [num_tokens, hidden_size]
  const scalar_t* __restrict__ input,   // [num_tokens, hidden_size]
  const scalar_t* __restrict__ weight,  // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
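  // blockReduceSum is assumed (from reduction_utils.cuh) to return the sum of
  // `variance` over all threads in the block; only thread 0 consumes it.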
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

} // namespace cacheflow
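
// Host entry point: applies RMSNorm to each row of `input` and writes the
// result to `out`. Launches one block per token with up to 1024 threads on the
// current CUDA stream, dispatching over float, half, and bfloat16.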
void rms_norm(
  torch::Tensor& out,     // [num_tokens, hidden_size]
  torch::Tensor& input,   // [num_tokens, hidden_size]
  torch::Tensor& weight,  // [hidden_size]
  float epsilon) {
  int num_tokens = input.size(0);
  int hidden_size = input.size(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
      cacheflow::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
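
// Usage sketch (assumption: the binding below is not part of this file; in the
// original project the op is registered elsewhere in the extension). It only
// illustrates how rms_norm could be exposed to Python via torch's extension
// machinery:
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("rms_norm", &rms_norm, "Apply RMSNorm to input and write to out (CUDA)");
//   }
//
// From Python one would then call, e.g.:
//   out = torch.empty_like(x)
//   module.rms_norm(out, x, weight, 1e-6)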