#include "cpu_types.hpp"

#include <cmath>  // for sqrtf

namespace {
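// Row-wise RMS normalization:
//   out[i, :] = input[i, :] / sqrt(mean(input[i, :]^2) + epsilon) * weight
// The reduction and scaling are done in FP32 regardless of scalar_t to
// avoid precision loss for fp16/bf16 inputs.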
template <typename scalar_t>
void rms_norm_impl(scalar_t* __restrict__ out,
                   const scalar_t* __restrict__ input,
                   const scalar_t* __restrict__ weight, const float epsilon,
                   const int num_tokens, const int hidden_size) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);

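  // Each token row is independent, so rows are processed in parallel.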
#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    vec_op::FP32Vec8 variance(0.0);
    auto input_p = input + i * hidden_size;
    auto output_p = out + i * hidden_size;
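    // First pass: accumulate the sum of squares for this row in FP32.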
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      vec_op::FP32Vec8 fp32_x(x);
      variance = variance + fp32_x * fp32_x;
    }

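    // Inverse RMS: 1 / sqrt(mean(x^2) + epsilon), broadcast to a vector.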
    float s_variance =
        1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
    vec_op::FP32Vec8 fp32_s_variance(s_variance);

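    // Second pass: scale each element by the inverse RMS and the weight,
    // then convert back to scalar_t and store.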
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      scalar_vec_t w(weight + j);

      vec_op::FP32Vec8 fp32_x(x);
      vec_op::FP32Vec8 fp32_w(w);

      vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w;

      scalar_vec_t out(fp32_out);
      out.save(output_p + j);
    }
  }
}

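// Fused residual-add + RMS norm, with both tensors updated in place:
//   residual[i, :] += input[i, :]   (pre-norm sum, kept for later layers)
//   input[i, :]    = RMSNorm(residual[i, :]) * weight
// Fusing the add into the norm saves one full pass over memory compared
// to running the two ops separately.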
template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
                             scalar_t* __restrict__ residual,
                             const scalar_t* __restrict__ weight,
                             const float epsilon, const int num_tokens,
                             const int hidden_size) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    vec_op::FP32Vec8 variance(0.0);
    auto input_p = input + i * hidden_size;
    auto residual_p = residual + i * hidden_size;
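    // First pass: add the residual, write the sum back into `residual`,
    // and accumulate the sum of squares of the updated values.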
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      scalar_vec_t res(residual_p + j);
      vec_op::FP32Vec8 fp32_x(x);
      vec_op::FP32Vec8 fp32_res(res);

      fp32_x = fp32_x + fp32_res;
      variance = variance + fp32_x * fp32_x;
      scalar_vec_t out(fp32_x);
      out.save(residual_p + j);
    }

    float s_variance =
        1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
    vec_op::FP32Vec8 fp32_s_variance(s_variance);

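    // Second pass: normalize the updated residual and write the result
    // over `input`.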
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t w(weight + j);
      scalar_vec_t res(residual_p + j);

      vec_op::FP32Vec8 fp32_w(w);
      vec_op::FP32Vec8 fp32_res(res);

      vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w;

      scalar_vec_t out(fp32_out);
      out.save(input_p + j);
    }
  }
}
}  // namespace

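// Public entry point: treats `input` as (num_tokens, hidden_size) and
// dispatches the templated kernel on the tensor's floating-point dtype.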
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
    CPU_KERNEL_GUARD_IN(rms_norm_impl)
    rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
                  weight.data_ptr<scalar_t>(), epsilon, num_tokens,
                  hidden_size);
    CPU_KERNEL_GUARD_OUT(rms_norm_impl)
  });
}

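// Public entry point: `input` and `residual` are both modified in place,
// as described on fused_add_rms_norm_impl above.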
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                        torch::Tensor& weight, float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "fused_add_rms_norm_impl", [&] {
        CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl)
        fused_add_rms_norm_impl(
            input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
        CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl)
      });
}
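
// Minimal usage sketch of the entry points above (illustrative only; the
// shapes are assumptions, and hidden_size must be a multiple of the vector
// width, e.g. 8 for fp32 given FP32Vec8):
//
//   auto x   = torch::randn({16, 4096});        // (num_tokens, hidden_size)
//   auto w   = torch::ones({4096});             // learned RMSNorm weight
//   auto out = torch::empty_like(x);
//   rms_norm(out, x, w, /*epsilon=*/1e-6f);
//
//   auto res = torch::randn({16, 4096});
//   fused_add_rms_norm(x, res, w, /*epsilon=*/1e-6f);
//   // now: res holds x + res (pre-norm sum), x holds RMSNorm(res) * w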