[Kernel][CPU] Add Quick gelu to CPU (#5717)
parent d9a252bc8e
commit bd620b01fb
@@ -59,6 +59,13 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
   return w3 * x * (ones + t);
 }
 
+FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) {
+  const vec_op::FP32Vec8 zeros(0.0);
+  const vec_op::FP32Vec8 ones(1.0);
+  const vec_op::FP32Vec8 w1(1.702f);
+  return x / (ones + (zeros - w1 * x).exp());
+}
+
 FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
   const vec_op::FP32Vec8 ones(1.0);
   const vec_op::FP32Vec8 w1(M_SQRT1_2);
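The new gelu_quick_act evaluates QuickGELU, x * sigmoid(1.702 * x), rewritten as x / (1 + exp(-1.702 * x)) so the vector class only needs its exp(). A minimal PyTorch sketch of the same math, for reference only (not part of the diff):

import torch

def quick_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Same algebraic form as the vector kernel above: x / (1 + exp(-1.702 * x)),
    # which equals x * sigmoid(1.702 * x).
    return x / (1.0 + torch.exp(-1.702 * x))

x = torch.randn(8, dtype=torch.float32)
assert torch.allclose(quick_gelu_reference(x), x * torch.sigmoid(1.702 * x), atol=1e-6)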
@@ -142,3 +149,15 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
     CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
   });
 }
+
+void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
+  int num_tokens = input.numel() / input.size(-1);
+  int d = input.size(-1);
+
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] {
+    CPU_KERNEL_GUARD_IN(gelu_quick_impl)
+    activation_kernel<scalar_t, gelu_quick_act, false>(
+        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
+    CPU_KERNEL_GUARD_OUT(gelu_quick_impl)
+  });
+}
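The C++ wrapper treats the input as num_tokens rows of width d = input.size(-1) and applies the activation elementwise into out. A rough PyTorch equivalent of that shape contract, as a sketch rather than the kernel itself:

import torch

def gelu_quick_cpu_reference(out: torch.Tensor, input: torch.Tensor) -> None:
    # Mirror the wrapper above: view the tensor as (num_tokens, d), apply the
    # activation elementwise, and write the result into `out`.
    d = input.size(-1)
    num_tokens = input.numel() // d
    x = input.view(num_tokens, d)
    out.view(num_tokens, d).copy_(x / (1.0 + torch.exp(-1.702 * x)))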
@@ -58,6 +58,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_fast", torch::kCPU, &gelu_fast);
 
+  // Quick GELU implementation.
+  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_quick", torch::kCPU, &gelu_quick);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
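With the def/impl pair above, the op becomes callable from Python through the torch.ops registry. A hypothetical usage sketch; the namespace "_C" is an assumption here, since the real name is set by TORCH_EXTENSION_NAME at build time:

import torch

x = torch.randn(4, 16, dtype=torch.float32)
out = torch.empty_like(x)
# "_C" is assumed; the schema declares `out` as mutated, so the result is
# written into it in place.
torch.ops._C.gelu_quick(out, x)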
@@ -43,6 +43,9 @@ class ipex_ops:
     def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
         out.copy_(torch.nn.functional.gelu(x))
 
+    # TODO add implementation of gelu_quick here
+    # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+
     def paged_attention_v1(
         out: torch.Tensor,
         query: torch.Tensor,
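The IPEX backend only gains a TODO in this commit. Not part of the diff, but a plain-PyTorch sketch of what the commented-out gelu_quick could look like once filled in:

import torch

# Sketch only -- this method does not exist in the commit. The formula is
# QuickGELU, x * sigmoid(1.702 * x), matching the CPU kernel added above.
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    out.copy_(x * torch.sigmoid(1.702 * x))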
@@ -155,6 +155,9 @@ class QuickGELU(CustomOp):
         ops.gelu_quick(out, x)
         return out
 
+    # TODO implement forward_xpu for QuickGELU
+    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
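QuickGELU only gains a forward_xpu TODO here; its CPU path already routes through ops.gelu_quick. A sketch of how the missing method might mirror that path once the IPEX op exists; the import location and method are assumptions, not part of the commit:

import torch

class QuickGELUSketch(torch.nn.Module):
    # Hypothetical forward_xpu mirroring the CPU path; ipex_ops.gelu_quick is
    # assumed to be implemented per the TODO in the previous hunk.
    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops  # assumed import location
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out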