Add kernel for GeGLU with approximate GELU (#3337)

Woosuk Kwon authored on 2024-03-12 22:06:17 -07:00, committed by GitHub
parent 49a3c8662b
commit 602358f8a8
5 changed files with 49 additions and 7 deletions


@@ -33,12 +33,25 @@ template<typename T>
 __device__ __forceinline__ T gelu_kernel(const T& x) {
   // Equivalent to PyTorch GELU with 'none' approximation.
   // Refer to:
-  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
   const float f = (float) x;
   constexpr float ALPHA = M_SQRT1_2;
   return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
 }
 
+template<typename T>
+__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
+  // Equivalent to PyTorch GELU with 'tanh' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
+  const float f = (float) x;
+  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
+  constexpr float KAPPA = 0.044715;
+  float x_cube = f * f * f;
+  float inner = BETA * (f + KAPPA * x_cube);
+  return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
+}
+
 } // namespace vllm
 
 // Launch activation and gating kernel.
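
Not part of the commit, but a sanity check on the constants above: BETA = M_SQRT2 * M_2_SQRTPI * 0.5 = sqrt(2) * (2/sqrt(pi)) * 0.5 = sqrt(2/pi), and KAPPA = 0.044715, which are exactly the coefficients of the standard tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A minimal PyTorch sketch verifying this against F.gelu with approximate="tanh":

# Sketch only; assumes a PyTorch version whose F.gelu accepts the `approximate` argument.
import math
import torch
import torch.nn.functional as F

BETA = math.sqrt(2.0) * (2.0 / math.sqrt(math.pi)) * 0.5  # M_SQRT2 * M_2_SQRTPI * 0.5 == sqrt(2/pi)
KAPPA = 0.044715

def gelu_tanh_ref(x: torch.Tensor) -> torch.Tensor:
    # Same formula as vllm::gelu_tanh_kernel, in plain PyTorch.
    inner = BETA * (x + KAPPA * x.pow(3))
    return 0.5 * x * (1.0 + torch.tanh(inner))

x = torch.randn(1024, dtype=torch.float32)
torch.testing.assert_close(gelu_tanh_ref(x), F.gelu(x, approximate="tanh"))
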
@@ -73,6 +86,13 @@ void gelu_and_mul(
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
 }
 
+void gelu_tanh_and_mul(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
+}
+
 namespace vllm {
 
 // Element-wise activation kernel template.


@@ -61,6 +61,10 @@ void gelu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);
 
+void gelu_tanh_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 void gelu_new(
   torch::Tensor& out,
   torch::Tensor& input);


@@ -25,7 +25,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def(
     "gelu_and_mul",
     &gelu_and_mul,
-    "Activation function used in GeGLU.");
+    "Activation function used in GeGLU with `none` approximation.");
+  ops.def(
+    "gelu_tanh_and_mul",
+    &gelu_tanh_and_mul,
+    "Activation function used in GeGLU with `tanh` approximation.");
   ops.def(
     "gelu_new",
     &gelu_new,
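
Not part of the diff, but for reference the new binding can be called directly from Python once the extension is built. This sketch assumes the compiled extension is importable as vllm._C.ops (the `ops` handle used by the Python layer further below); the output tensor must be preallocated with the gated width d:

import torch
from vllm._C import ops  # assumed module path for the compiled extension

d = 256
x = torch.randn(4, 2 * d, dtype=torch.float16, device="cuda")
out = torch.empty(4, d, dtype=x.dtype, device=x.device)
ops.gelu_tanh_and_mul(out, x)  # out = GELU_tanh(x[..., :d]) * x[..., d:]
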


@@ -16,7 +16,7 @@ CUDA_DEVICES = [
 ]
 
-@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
+@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -24,7 +24,7 @@ CUDA_DEVICES = [
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_act_and_mul(
-    activation: Type[torch.nn.Module],
+    activation: str,
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -36,7 +36,12 @@ def test_act_and_mul(
     torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    layer = activation()
+    if activation == "silu":
+        layer = SiluAndMul()
+    elif activation == "gelu":
+        layer = GeluAndMul(approximate="none")
+    elif activation == "gelu_tanh":
+        layer = GeluAndMul(approximate="tanh")
     out = layer(x)
     ref_out = layer._forward(x)
     # The SiLU and GELU implementations are equivalent to the native PyTorch


@@ -47,16 +47,25 @@ class GeluAndMul(nn.Module):
         return: (batch_size, seq_len, d) or (num_tokens, d)
     """
 
+    def __init__(self, approximate: str = "none"):
+        super().__init__()
+        self.approximate = approximate
+        if approximate not in ("none", "tanh"):
+            raise ValueError(f"Unknown approximate mode: {approximate}")
+
     def _forward(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d]) * x[..., d:]
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.gelu_and_mul(out, x)
+        if self.approximate == "none":
+            ops.gelu_and_mul(out, x)
+        elif self.approximate == "tanh":
+            ops.gelu_tanh_and_mul(out, x)
         return out
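
Not part of the diff: a short usage sketch of the updated layer. The import path assumes the layer lives at vllm.model_executor.layers.activation, as in the vLLM source tree, and that a CUDA device is available.

import torch
from vllm.model_executor.layers.activation import GeluAndMul  # assumed path

layer = GeluAndMul(approximate="tanh")
x = torch.randn(8, 2 * 128, dtype=torch.float16, device="cuda")
out = layer(x)  # dispatches to the new ops.gelu_tanh_and_mul kernel
assert out.shape == (8, 128)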