[Performance][Kernel] Fused_moe Performance Improvement (#9384)

Signed-off-by: charlifu <charlifu@amd.com>
Charlie Fu 2024-10-24 17:37:52 -05:00 committed by GitHub
parent e26d37a185
commit 59449095ab
9 changed files with 118 additions and 38 deletions


@@ -195,7 +195,6 @@ set(VLLM_EXT_SRC
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp")
@@ -405,6 +404,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu")
set_gencode_flags_for_srcs(


@@ -1,15 +1,17 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
namespace vllm {
+namespace moe {
namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
@@ -32,10 +34,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
extern __shared__ int32_t shared_mem[];
int32_t* tokens_cnts =
-shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
+shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
int32_t* cumsum =
-shared_mem + (num_experts + 1) *
-    num_experts; // 1d tensor with shape (num_experts + 1)
+shared_mem +
+    (blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@@ -53,10 +55,12 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
-tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
-for (int i = 1; i <= blockDim.x; ++i) {
-  tokens_cnts[index(num_experts, i, threadIdx.x)] +=
-      tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
-}
+if (threadIdx.x < num_experts) {
+  tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
+  for (int i = 1; i <= blockDim.x; ++i) {
+    tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+        tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+  }
+}
__syncthreads();
@@ -79,9 +83,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
-for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
-     i += block_size) {
-  expert_ids[i / block_size] = threadIdx.x;
-}
+if (threadIdx.x < num_experts) {
+  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+       i += block_size) {
+    expert_ids[i / block_size] = threadIdx.x;
+  }
+}
/**
@@ -106,6 +112,24 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
}
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., topk, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
scalar_t x = 0.0;
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
}
out[token_idx * d + idx] = x;
}
}
} // namespace moe
} // namespace vllm
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
@@ -117,18 +141,62 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
+const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem =
-((num_experts + 1) * num_experts + (num_experts + 1)) *
+((num_thread + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);
// set dynamic shared mem
-auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
+auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
-kernel<<<1, num_experts, shared_mem, stream>>>(
+kernel<<<1, num_thread, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
});
}
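For readers of this hunk: moe_align_block_size pads each expert's token count up to a multiple of block_size so the fused MoE kernels can work on uniform blocks. A rough pure-PyTorch sketch of the same semantics follows; the helper name ref_moe_align_block_size, the CPU-side loop, and the sentinel fill value for padded slots are illustrative assumptions, and ordering within an expert may differ from the kernel's.

import torch

def ref_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                             block_size: int):
    # Count tokens routed to each expert, then round each count up to a
    # multiple of block_size.
    flat = topk_ids.flatten()
    counts = torch.bincount(flat, minlength=num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size
    offsets = torch.cumsum(padded, dim=0)
    num_tokens_post_pad = int(offsets[-1])

    # Padding slots keep a sentinel one past the last real token (assumption).
    sorted_token_ids = torch.full((num_tokens_post_pad,), flat.numel(),
                                  dtype=torch.int32)
    expert_ids = torch.empty(num_tokens_post_pad // block_size,
                             dtype=torch.int32)
    start = 0
    for e in range(num_experts):
        idx = (flat == e).nonzero(as_tuple=True)[0].to(torch.int32)
        sorted_token_ids[start:start + idx.numel()] = idx
        expert_ids[start // block_size:(start + int(padded[e])) // block_size] = e
        start += int(padded[e])
    return sorted_token_ids, expert_ids, num_tokens_post_pad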
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
torch::Tensor& output) // [num_tokens, hidden_size]
{
const int hidden_size = input.size(-1);
const int num_tokens = output.numel() / hidden_size;
const int topk = input.size(1);
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (topk) {
case 2:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 3:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 4:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
default:
at::sum_out(output, input, 1);
break;
}
}
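The new moe_sum host function reduces the per-expert partial results over the topk dimension: topk values of 2, 3, and 4 get a templated kernel with an unrolled inner loop, and anything else falls back to at::sum_out. A hedged sanity-check sketch of the same reduction from Python (shapes, dtype, and the closeness check are illustrative; it assumes the _moe_C extension is built and a CUDA device is available):

import torch

num_tokens, topk, hidden_size = 8, 2, 4096  # example shapes only
x = torch.randn(num_tokens, topk, hidden_size, dtype=torch.float16, device="cuda")
out = torch.empty(num_tokens, hidden_size, dtype=x.dtype, device=x.device)

torch.ops._moe_C.moe_sum(x, out)               # fused reduction added here
torch.testing.assert_close(out, x.sum(dim=1))  # matches a plain sum over topk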


@@ -5,3 +5,10 @@
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
void moe_sum(torch::Tensor& input, torch::Tensor& output);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);


@@ -8,6 +8,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"token_expert_indices, Tensor gating_output) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m.def("moe_sum(Tensor! input, Tensor output) -> ()");
m.impl("moe_sum", torch::kCUDA, &moe_sum);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m.def(
"moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
#ifndef USE_ROCM
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "


@@ -145,11 +145,6 @@ void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
c10::optional<torch::Tensor> const& scale_ub);
-void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
-                          int64_t block_size, torch::Tensor sorted_token_ids,
-                          torch::Tensor experts_ids,
-                          torch::Tensor num_tokens_post_pad);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& C,


@@ -336,15 +336,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);
-// Aligning the number of tokens to be processed by each expert such
-// that it is divisible by the block size.
-ops.def(
-    "moe_align_block_size(Tensor topk_ids, int num_experts,"
-    " int block_size, Tensor! sorted_token_ids,"
-    " Tensor! experts_ids,"
-    " Tensor! num_tokens_post_pad) -> ()");
-ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"


@@ -19,7 +19,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types
-from vllm.utils import seed_everything
+from vllm.utils import is_hip, seed_everything
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@@ -103,6 +103,7 @@ def test_mixtral_moe(dtype: torch.dtype):
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
+@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
def test_fused_marlin_moe(
m: int,
n: int,
@@ -255,6 +256,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
+@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
def test_single_marlin_moe_multiply(
m: int,
n: int,
@@ -345,6 +347,6 @@ def test_moe_align_block_size_opcheck():
dtype=torch.int32,
device=topk_ids.device)
-opcheck(torch.ops._C.moe_align_block_size,
+opcheck(torch.ops._moe_C.moe_align_block_size,
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
num_tokens_post_pad))


@@ -813,13 +813,17 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
# moe
+def moe_sum(input: torch.Tensor, output: torch.Tensor):
+    torch.ops._moe_C.moe_sum(input, output)
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
-    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
-                                      sorted_token_ids, experts_ids,
-                                      num_tokens_post_pad)
+    torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size,
+                                          sorted_token_ids, experts_ids,
+                                          num_tokens_post_pad)
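Callers of the relocated wrapper still allocate the output buffers themselves. A hedged allocation-and-call sketch (the padded-size formula and sentinel fill mirror how fused_moe typically sizes these buffers and are assumptions here, as are the example shapes):

import torch
from vllm import _custom_ops as ops

num_experts, block_size = 8, 16
topk_ids = torch.randint(0, num_experts, (256, 2), dtype=torch.int32, device="cuda")

# Worst case: every expert needs (block_size - 1) extra padding slots (assumed sizing).
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.full((max_num_tokens_padded,), topk_ids.numel(),
                        dtype=torch.int32, device="cuda")
expert_ids = torch.empty((max_num_tokens_padded + block_size - 1) // block_size,
                         dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                         expert_ids, num_tokens_post_pad)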
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,


@@ -589,9 +589,8 @@ def fused_experts(hidden_states: torch.Tensor,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16)
-torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-          dim=1,
-          out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
+ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
+            out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
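Since the commit title claims a performance improvement, a rough way to compare the replaced torch.sum reduction with the new fused op is a simple CUDA-event micro-benchmark; shapes, iteration count, and results below are illustrative and will vary by GPU and build.

import torch
from vllm import _custom_ops as ops

def time_ms(fn, iters=100):
    # Average wall time of fn over `iters` launches, measured with CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()                      # warmup
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

cache3 = torch.randn(4096, 2, 4096, dtype=torch.float16, device="cuda")
out = torch.empty(4096, 4096, dtype=cache3.dtype, device=cache3.device)
print("torch.sum  :", time_ms(lambda: torch.sum(cache3, dim=1, out=out)))
print("ops.moe_sum:", time_ms(lambda: ops.moe_sum(cache3, out)))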