[Kernel] add kernel for FATReLU (#9610)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
8a02cd045a
commit
295a061fb3
@ -89,6 +89,48 @@ void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
|
|||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
|
||||||
|
const float f = (float)x;
|
||||||
|
return (T)(f > threshold ? f : 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
|
||||||
|
__global__ void act_and_mul_kernel_with_param(
|
||||||
|
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
|
||||||
|
const float param) {
|
||||||
|
const int64_t token_idx = blockIdx.x;
|
||||||
|
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||||
|
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||||
|
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||||
|
out[token_idx * d + idx] = ACT_FN(x, param) * y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
|
||||||
|
int d = input.size(-1) / 2; \
|
||||||
|
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||||
|
dim3 grid(num_tokens); \
|
||||||
|
dim3 block(std::min(d, 1024)); \
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||||
|
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
|
||||||
|
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
|
||||||
|
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
||||||
|
input.data_ptr<scalar_t>(), d, \
|
||||||
|
PARAM); \
|
||||||
|
});
|
||||||
|
|
||||||
|
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
|
||||||
|
torch::Tensor& input, // [..., 2 * d]
|
||||||
|
double threshold) {
|
||||||
|
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
|
||||||
|
}
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
// Element-wise activation kernel template.
|
// Element-wise activation kernel template.
|
||||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||||
__global__ void activation_kernel(
|
__global__ void activation_kernel(
|
||||||
|
@ -48,6 +48,9 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
|||||||
|
|
||||||
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
|
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||||
|
|
||||||
|
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
|
||||||
|
double threshold);
|
||||||
|
|
||||||
void gelu_new(torch::Tensor& out, torch::Tensor& input);
|
void gelu_new(torch::Tensor& out, torch::Tensor& input);
|
||||||
|
|
||||||
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
||||||
|
@ -60,6 +60,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
|
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
|
||||||
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
||||||
|
|
||||||
|
// FATReLU implementation.
|
||||||
|
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
|
||||||
|
ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
|
||||||
|
|
||||||
// GELU implementation used in GPT-2.
|
// GELU implementation used in GPT-2.
|
||||||
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
|
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
|
||||||
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
|
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
|
import random
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.kernels.utils import opcheck
|
from tests.kernels.utils import opcheck
|
||||||
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
|
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
|
||||||
NewGELU, QuickGELU,
|
GeluAndMul, NewGELU,
|
||||||
SiluAndMul)
|
QuickGELU, SiluAndMul)
|
||||||
from vllm.utils import seed_everything
|
from vllm.utils import seed_everything
|
||||||
|
|
||||||
from .allclose_default import get_default_atol, get_default_rtol
|
from .allclose_default import get_default_atol, get_default_rtol
|
||||||
@ -20,7 +21,8 @@ CUDA_DEVICES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
|
@pytest.mark.parametrize("activation",
|
||||||
|
["silu", "gelu", "gelu_tanh", "fatrelu"])
|
||||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
@pytest.mark.parametrize("d", D)
|
@pytest.mark.parametrize("d", D)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@ -47,16 +49,23 @@ def test_act_and_mul(
|
|||||||
elif activation == "gelu_tanh":
|
elif activation == "gelu_tanh":
|
||||||
layer = GeluAndMul(approximate="tanh")
|
layer = GeluAndMul(approximate="tanh")
|
||||||
fn = torch.ops._C.gelu_tanh_and_mul
|
fn = torch.ops._C.gelu_tanh_and_mul
|
||||||
|
elif activation == "fatrelu":
|
||||||
|
threshold = random.uniform(0, 1)
|
||||||
|
layer = FatreluAndMul(threshold)
|
||||||
|
fn = torch.ops._C.fatrelu_and_mul
|
||||||
out = layer(x)
|
out = layer(x)
|
||||||
ref_out = layer.forward_native(x)
|
ref_out = layer.forward_native(x)
|
||||||
# The SiLU and GELU implementations are equivalent to the native PyTorch
|
# The SiLU, GELU and FatReLU implementations are equivalent to the native
|
||||||
# implementations, so we can do exact comparison.
|
# PyTorch implementations, so we can do exact comparison.
|
||||||
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
|
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
|
||||||
|
|
||||||
d = x.shape[-1] // 2
|
d = x.shape[-1] // 2
|
||||||
output_shape = (x.shape[:-1] + (d, ))
|
output_shape = (x.shape[:-1] + (d, ))
|
||||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||||
opcheck(fn, (out, x))
|
if activation == "fatrelu":
|
||||||
|
opcheck(fn, (out, x, threshold))
|
||||||
|
else:
|
||||||
|
opcheck(fn, (out, x))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
|
@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
|
||||||
|
@ -79,6 +79,12 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
|||||||
torch.ops._C.gelu_tanh_and_mul(out, x)
|
torch.ops._C.gelu_tanh_and_mul(out, x)
|
||||||
|
|
||||||
|
|
||||||
|
def fatrelu_and_mul(out: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
threshold: float = 0.0) -> None:
|
||||||
|
torch.ops._C.fatrelu_and_mul(out, x, threshold)
|
||||||
|
|
||||||
|
|
||||||
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
|
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
|
||||||
torch.ops._C.gelu_fast(out, x)
|
torch.ops._C.gelu_fast(out, x)
|
||||||
|
|
||||||
|
@ -39,7 +39,13 @@ class FatreluAndMul(CustomOp):
|
|||||||
return x1 * x2
|
return x1 * x2
|
||||||
|
|
||||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return self.forward_native(x)
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
|
d = x.shape[-1] // 2
|
||||||
|
output_shape = (x.shape[:-1] + (d, ))
|
||||||
|
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||||
|
ops.fatrelu_and_mul(out, x, self.threshold)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
@CustomOp.register("silu_and_mul")
|
@CustomOp.register("silu_and_mul")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user