Add kernel for GeGLU with approximate GELU (#3337)
This commit is contained in: parent 49a3c8662b, commit 602358f8a8
@@ -33,12 +33,25 @@ template<typename T>
 __device__ __forceinline__ T gelu_kernel(const T& x) {
   // Equivalent to PyTorch GELU with 'none' approximation.
   // Refer to:
-  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
   const float f = (float) x;
   constexpr float ALPHA = M_SQRT1_2;
   return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
 }
 
+template<typename T>
+__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
+  // Equivalent to PyTorch GELU with 'tanh' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
+  const float f = (float) x;
+  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
+  constexpr float KAPPA = 0.044715;
+  float x_cube = f * f * f;
+  float inner = BETA * (f + KAPPA * x_cube);
+  return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
+}
+
 } // namespace vllm
 
 // Launch activation and gating kernel.

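For reference, the constants in the kernels above (presumably csrc/activation_kernels.cu) are ALPHA = 1/sqrt(2), BETA = sqrt(2/pi), and KAPPA = 0.044715, i.e. the exact ('none') GELU and its tanh approximation. A minimal PyTorch sketch of both forms, for comparison only and not part of this commit (assumes only `torch`):

import math
import torch
import torch.nn.functional as F

ALPHA = 1.0 / math.sqrt(2.0)     # M_SQRT1_2
BETA = math.sqrt(2.0 / math.pi)  # M_SQRT2 * M_2_SQRTPI * 0.5
KAPPA = 0.044715

def gelu_none_ref(x: torch.Tensor) -> torch.Tensor:
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + torch.erf(x * ALPHA))

def gelu_tanh_ref(x: torch.Tensor) -> torch.Tensor:
    # Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + torch.tanh(BETA * (x + KAPPA * x ** 3)))

x = torch.randn(1024, dtype=torch.float32)
assert torch.allclose(gelu_none_ref(x), F.gelu(x, approximate="none"), atol=1e-6)
assert torch.allclose(gelu_tanh_ref(x), F.gelu(x, approximate="tanh"), atol=1e-6)
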
@@ -73,6 +86,13 @@ void gelu_and_mul(
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
 }
 
+void gelu_tanh_and_mul(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
+}
+
 namespace vllm {
 
 // Element-wise activation kernel template.

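The new launcher reuses the existing activation-and-gate machinery. Assuming LAUNCH_ACTIVATION_GATE_KERNEL applies the given element-wise activation to the first half of the last dimension and multiplies by the second half (matching the Python `_forward` further down), the fused op is equivalent to this illustrative sketch:

import torch
import torch.nn.functional as F

def gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # GeGLU gating: activation on the first half of the last dim, times the second half.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]

x = torch.randn(4, 2 * 128)
out = gelu_tanh_and_mul_ref(x)  # -> shape (4, 128)
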
@@ -61,6 +61,10 @@ void gelu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);
 
+void gelu_tanh_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 void gelu_new(
   torch::Tensor& out,
   torch::Tensor& input);

@@ -25,7 +25,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def(
     "gelu_and_mul",
     &gelu_and_mul,
-    "Activation function used in GeGLU.");
+    "Activation function used in GeGLU with `none` approximation.");
+  ops.def(
+    "gelu_tanh_and_mul",
+    &gelu_tanh_and_mul,
+    "Activation function used in GeGLU with `tanh` approximation.");
   ops.def(
     "gelu_new",
     &gelu_new,

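Once registered, the op can be called directly from Python with a preallocated output buffer, mirroring how `gelu_and_mul` is already used in the layer code below. A usage sketch; the `vllm._C` import path is an assumption about how the pybind11 extension module is exposed:

import torch
from vllm._C import ops  # assumed location of the compiled extension

x = torch.randn(16, 2 * 512, dtype=torch.float16, device="cuda")
out = torch.empty(16, 512, dtype=x.dtype, device=x.device)  # [..., d] output buffer
ops.gelu_tanh_and_mul(out, x)  # fills `out` with GELU_tanh(x[..., :d]) * x[..., d:]
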
@@ -16,7 +16,7 @@ CUDA_DEVICES = [
 ]
 
 
-@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
+@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)

@@ -24,7 +24,7 @@ CUDA_DEVICES = [
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_act_and_mul(
-    activation: Type[torch.nn.Module],
+    activation: str,
     num_tokens: int,
     d: int,
     dtype: torch.dtype,

@@ -36,7 +36,12 @@ def test_act_and_mul(
     torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    layer = activation()
+    if activation == "silu":
+        layer = SiluAndMul()
+    elif activation == "gelu":
+        layer = GeluAndMul(approximate="none")
+    elif activation == "gelu_tanh":
+        layer = GeluAndMul(approximate="tanh")
     out = layer(x)
     ref_out = layer._forward(x)
     # The SiLU and GELU implementations are equivalent to the native PyTorch

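In the test, `out` (custom CUDA kernel path) and `ref_out` (PyTorch-native `_forward`) are expected to agree. A self-contained sketch of that comparison; the import path and tolerances here are illustrative assumptions, not values taken from the repository:

import torch
from vllm.model_executor.layers.activation import GeluAndMul  # assumed import path

layer = GeluAndMul(approximate="tanh")
x = torch.randn(7, 2 * 512, dtype=torch.half, device="cuda")
out = layer(x)               # custom CUDA kernel path
ref_out = layer._forward(x)  # PyTorch-native reference
torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-3)  # illustrative tolerances
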
@@ -47,16 +47,25 @@ class GeluAndMul(nn.Module):
         return: (batch_size, seq_len, d) or (num_tokens, d)
     """
 
+    def __init__(self, approximate: str = "none"):
+        super().__init__()
+        self.approximate = approximate
+        if approximate not in ("none", "tanh"):
+            raise ValueError(f"Unknown approximate mode: {approximate}")
+
     def _forward(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d]) * x[..., d:]
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.gelu_and_mul(out, x)
+        if self.approximate == "none":
+            ops.gelu_and_mul(out, x)
+        elif self.approximate == "tanh":
+            ops.gelu_tanh_and_mul(out, x)
         return out
 
 
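For context, a minimal sketch of where `GeluAndMul(approximate="tanh")` would sit inside a GeGLU-style feed-forward block; the surrounding module, sizes, and import path are illustrative assumptions, not code from this commit:

import torch
import torch.nn as nn
from vllm.model_executor.layers.activation import GeluAndMul  # assumed import path

class GegluMLP(nn.Module):
    """Hypothetical GeGLU feed-forward block for illustration only."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # One projection produces both the gate and the up branch: [..., 2 * intermediate_size].
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.act_fn = GeluAndMul(approximate="tanh")  # dispatches to ops.gelu_tanh_and_mul on CUDA
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))

mlp = GegluMLP(hidden_size=1024, intermediate_size=4096).cuda().half()
y = mlp(torch.randn(2, 1024, dtype=torch.half, device="cuda"))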