[Kernel] Support MulAndSilu (#11624)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Parent: a3a3ee4e6f
Commit: 42f5e7c52a
CUDA activation and gating kernels:

@@ -9,8 +9,16 @@
 namespace vllm {
 
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
+__device__ __forceinline__ scalar_t compute(const scalar_t& x,
+                                            const scalar_t& y) {
+  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
+}
 // Activation and gating kernel template.
-template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
 __global__ void act_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]

@@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = ACT_FN(x) * y;
+    out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
   }
 }

@@ -55,7 +63,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 }  // namespace vllm
 
 // Launch activation and gating kernel.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                            \
+// Use ACT_FIRST (bool) indicating whether to apply the activation function
+// first.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)                 \
   int d = input.size(-1) / 2;                                            \
   int64_t num_tokens = input.numel() / input.size(-1);                   \
   dim3 grid(num_tokens);                                                 \

@@ -64,7 +74,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
   VLLM_DISPATCH_FLOATING_TYPES(                                          \
       input.scalar_type(), "act_and_mul_kernel", [&] {                   \
-        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>             \
+        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>  \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                         input.data_ptr<scalar_t>(), d);  \
       });

@@ -72,19 +82,27 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
+}
+
+void mul_and_silu(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
+{
+  // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
+  // applies the silu to the latter half of the input.
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
 }
 
 void gelu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
 }
 
 void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
 }
 
 namespace vllm {
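Note: the act_first template parameter only selects which half of the [..., 2 * d] input the activation is applied to; the memory layout and launch path are unchanged. A minimal PyTorch sketch of the semantics (illustration only, not the kernel itself):

    import torch
    import torch.nn.functional as F

    def act_and_mul_reference(x: torch.Tensor, act_first: bool) -> torch.Tensor:
        # x: [..., 2 * d]; the kernel reads the same two halves in both cases.
        a, b = x.chunk(2, dim=-1)
        # act_first=True  -> silu_and_mul: silu(a) * b
        # act_first=False -> mul_and_silu: a * silu(b)
        return F.silu(a) * b if act_first else a * F.silu(b)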
C++ op declarations:

@@ -86,6 +86,8 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
+void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
+
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
Torch library bindings:

@@ -55,6 +55,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
 
+  ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
+  ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
+
   // Activation function used in GeGLU with `none` approximation.
   ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
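With the binding in place, the op is callable as an out-variant from Python. A short usage sketch (assumes vLLM's _C extension is built and a CUDA device is available):

    import torch

    d = 128
    x = torch.randn(8, 2 * d, dtype=torch.float16, device="cuda")
    out = torch.empty(8, d, dtype=x.dtype, device=x.device)
    # Writes x[..., :d] * silu(x[..., d:]) into out.
    torch.ops._C.mul_and_silu(out, x)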
Kernel tests (test_act_and_mul):

@@ -6,8 +6,9 @@ import torch
 
 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
-                                                   GeluAndMul, NewGELU,
-                                                   QuickGELU, SiluAndMul)
+                                                   GeluAndMul, MulAndSilu,
+                                                   NewGELU, QuickGELU,
+                                                   SiluAndMul)
 from vllm.platforms import current_platform
 
 from .allclose_default import get_default_atol, get_default_rtol

@@ -21,8 +22,9 @@ CUDA_DEVICES = [
 ]
 
 
-@pytest.mark.parametrize("activation",
-                         ["silu", "gelu", "gelu_tanh", "fatrelu"])
+@pytest.mark.parametrize(
+    "activation",
+    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)

@@ -40,9 +42,12 @@ def test_act_and_mul(
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    if activation == "silu":
+    if activation == "silu_and_mul":
         layer = SiluAndMul()
         fn = torch.ops._C.silu_and_mul
+    if activation == "mul_and_silu":
+        layer = MulAndSilu()
+        fn = torch.ops._C.mul_and_silu
     elif activation == "gelu":
         layer = GeluAndMul(approximate="none")
         fn = torch.ops._C.gelu_and_mul

@@ -55,8 +60,9 @@ def test_act_and_mul(
         fn = torch.ops._C.fatrelu_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
-    # The SiLU, GELU and FatReLU implementations are equivalent to the native
-    # PyTorch implementations, so we can do exact comparison.
+    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
+    # equivalent to the native PyTorch implementations, so we can do exact
+    # comparison.
     torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
 
     d = x.shape[-1] // 2
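The rest of the test (outside this hunk) follows the existing per-op validation pattern; a hedged sketch of how the new op would be exercised there, assuming the usual opcheck(fn, (out, x)) helper signature from tests.kernels.utils:

    # Continuation sketch, not the literal test body.
    d = x.shape[-1] // 2
    out = torch.empty(x.shape[:-1] + (d, ), dtype=x.dtype, device=x.device)
    opcheck(torch.ops._C.mul_and_silu, (out, x))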
Python activation layers (vllm.model_executor.layers.activation):

@@ -87,6 +87,41 @@ class SiluAndMul(CustomOp):
         return out
 
 
+@CustomOp.register("mul_and_silu")
+class MulAndSilu(CustomOp):
+    """An activation function for SwiGLU.
+
+    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.mul_and_silu
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.silu_and_mul
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return x[..., :d] * F.silu(x[..., d:])
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        self.op(out, x)
+        return out
+
+    # TODO implement forward_xpu for MulAndSilu
+    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+
+
 @CustomOp.register("gelu_and_mul")
 class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
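A minimal usage sketch of the new layer (assuming a working vLLM install; on CUDA-like platforms forward() dispatches to the fused torch.ops._C.mul_and_silu kernel):

    import torch
    from vllm.model_executor.layers.activation import MulAndSilu

    layer = MulAndSilu()
    x = torch.randn(4, 2 * 64)      # [..., 2 * d]
    out = layer.forward_native(x)   # [..., d] == x[..., :64] * silu(x[..., 64:])
    print(out.shape)                # torch.Size([4, 64])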
Molmo model:

@@ -23,7 +23,8 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
+from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
+                                                   SiluAndMul)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,

@@ -462,15 +463,6 @@ class MolmoAttention(nn.Module):
         return output
 
 
-class SwiGLU(nn.Module):
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = x.chunk(2, dim=-1)
-        # Note that the order is reversed compared to
-        # SiluAndMul.
-        return x * F.silu(gate)
-
-
 class LanuageModelMLP(nn.Module):
     """Molmo's LLM mlp."""

@@ -489,7 +481,7 @@ class LanuageModelMLP(nn.Module):
             quant_config=quant_config,
         )
         # Activation function.
-        self.act_fn = SwiGLU()
+        self.act_fn = MulAndSilu()
         # Feed-forward output projection.
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
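The removed SwiGLU module computed x * silu(gate) on the chunked halves, which is exactly what MulAndSilu.forward_native does, so the swap is a drop-in that also gains the fused CUDA path. A quick equivalence sketch (shapes are illustrative):

    import torch
    import torch.nn.functional as F
    from vllm.model_executor.layers.activation import MulAndSilu

    x = torch.randn(2, 16, 2 * 512)

    # What Molmo's old SwiGLU module did:
    h, gate = x.chunk(2, dim=-1)
    old = h * F.silu(gate)

    # What the MLP now uses:
    new = MulAndSilu().forward_native(x)
    torch.testing.assert_close(old, new)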
Ultravox model:

@@ -16,7 +16,7 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
 from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
+from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader

@@ -248,15 +248,6 @@ class StackAudioFrames(nn.Module):
         return audio_embeds
 
 
-class FlippedSiluAndMul(SiluAndMul):
-    """Ultravox is trained with SwiGLU with flipped halves."""
-
-    def forward(self, x: torch.Tensor):
-        a, b = x.chunk(2, dim=-1)
-        flipped = torch.cat((b, a), dim=-1)
-        return super().forward(flipped)
-
-
 class UltravoxProjector(nn.Module):
 
     def __init__(self, config: UltravoxConfig):

@@ -269,7 +260,7 @@ class UltravoxProjector(nn.Module):
         dim = self.hidden_dim
 
         if config.projector_act == "swiglu":
-            self.act = FlippedSiluAndMul()
+            self.act = MulAndSilu()
             dim = dim // 2
         else:
             self.act = get_act_fn(config.projector_act)
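Likewise, the removed FlippedSiluAndMul (SiluAndMul applied to the halves in reverse order) reduces to MulAndSilu; a short sketch of the identity:

    import torch
    from vllm.model_executor.layers.activation import MulAndSilu, SiluAndMul

    x = torch.randn(3, 2 * 32)
    a, b = x.chunk(2, dim=-1)

    # Old Ultravox path: flip the halves, then SiluAndMul -> silu(b) * a.
    flipped = SiluAndMul().forward_native(torch.cat((b, a), dim=-1))
    # New path: MulAndSilu -> a * silu(b), the same tensor.
    torch.testing.assert_close(MulAndSilu().forward_native(x), flipped)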