Add kernel for GeGLU with approximate GELU (#3337)
This commit is contained in:
parent
49a3c8662b
commit
602358f8a8
@ -33,12 +33,25 @@ template<typename T>
|
||||
__device__ __forceinline__ T gelu_kernel(const T& x) {
|
||||
// Equivalent to PyTorch GELU with 'none' approximation.
|
||||
// Refer to:
|
||||
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
|
||||
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
|
||||
const float f = (float) x;
|
||||
constexpr float ALPHA = M_SQRT1_2;
|
||||
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
||||
// Equivalent to PyTorch GELU with 'tanh' approximation.
|
||||
// Refer to:
|
||||
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
|
||||
const float f = (float) x;
|
||||
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
|
||||
constexpr float KAPPA = 0.044715;
|
||||
float x_cube = f * f * f;
|
||||
float inner = BETA * (f + KAPPA * x_cube);
|
||||
return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// Launch activation and gating kernel.
|
||||
@ -73,6 +86,13 @@ void gelu_and_mul(
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
|
||||
}
|
||||
|
||||
void gelu_tanh_and_mul(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Element-wise activation kernel template.
|
||||
|
@ -61,6 +61,10 @@ void gelu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_tanh_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
@ -25,7 +25,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
ops.def(
|
||||
"gelu_and_mul",
|
||||
&gelu_and_mul,
|
||||
"Activation function used in GeGLU.");
|
||||
"Activation function used in GeGLU with `none` approximation.");
|
||||
ops.def(
|
||||
"gelu_tanh_and_mul",
|
||||
&gelu_tanh_and_mul,
|
||||
"Activation function used in GeGLU with `tanh` approximation.");
|
||||
ops.def(
|
||||
"gelu_new",
|
||||
&gelu_new,
|
||||
|
@ -16,7 +16,7 @@ CUDA_DEVICES = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
|
||||
@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@ -24,7 +24,7 @@ CUDA_DEVICES = [
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_act_and_mul(
|
||||
activation: Type[torch.nn.Module],
|
||||
activation: str,
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
@ -36,7 +36,12 @@ def test_act_and_mul(
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
|
||||
layer = activation()
|
||||
if activation == "silu":
|
||||
layer = SiluAndMul()
|
||||
elif activation == "gelu":
|
||||
layer = GeluAndMul(approximate="none")
|
||||
elif activation == "gelu_tanh":
|
||||
layer = GeluAndMul(approximate="tanh")
|
||||
out = layer(x)
|
||||
ref_out = layer._forward(x)
|
||||
# The SiLU and GELU implementations are equivalent to the native PyTorch
|
||||
|
@ -47,16 +47,25 @@ class GeluAndMul(nn.Module):
|
||||
return: (batch_size, seq_len, d) or (num_tokens, d)
|
||||
"""
|
||||
|
||||
def __init__(self, approximate: str = "none"):
|
||||
super().__init__()
|
||||
self.approximate = approximate
|
||||
if approximate not in ("none", "tanh"):
|
||||
raise ValueError(f"Unknown approximate mode: {approximate}")
|
||||
|
||||
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.gelu(x[..., :d]) * x[..., d:]
|
||||
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
if self.approximate == "none":
|
||||
ops.gelu_and_mul(out, x)
|
||||
elif self.approximate == "tanh":
|
||||
ops.gelu_tanh_and_mul(out, x)
|
||||
return out
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user