[Performance][Kernel] Fused_moe Performance Improvement (#9384)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
parent
e26d37a185
commit
59449095ab
@ -195,7 +195,6 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
|
||||
"csrc/quantization/fp8/common.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/moe_align_block_size_kernels.cu"
|
||||
"csrc/prepare_inputs/advance_step.cu"
|
||||
"csrc/torch_bindings.cpp")
|
||||
|
||||
@ -405,6 +404,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
|
||||
|
||||
set(VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/torch_bindings.cpp"
|
||||
"csrc/moe/moe_align_sum_kernels.cu"
|
||||
"csrc/moe/topk_softmax_kernels.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
|
@ -1,15 +1,17 @@
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "../cuda_compat.h"
|
||||
#include "../dispatch_utils.h"
|
||||
|
||||
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
namespace {
|
||||
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
|
||||
@ -32,10 +34,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
|
||||
int32_t* tokens_cnts =
|
||||
shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
|
||||
shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
|
||||
int32_t* cumsum =
|
||||
shared_mem + (num_experts + 1) *
|
||||
num_experts; // 1d tensor with shape (num_experts + 1)
|
||||
shared_mem +
|
||||
(blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
|
||||
@ -53,10 +55,12 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
__syncthreads();
|
||||
|
||||
// For each expert we accumulate the token counts from the different threads.
|
||||
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
|
||||
for (int i = 1; i <= blockDim.x; ++i) {
|
||||
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
|
||||
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
|
||||
if (threadIdx.x < num_experts) {
|
||||
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
|
||||
for (int i = 1; i <= blockDim.x; ++i) {
|
||||
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
|
||||
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@ -79,9 +83,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
* For each expert, each thread processes the tokens of the corresponding
|
||||
* blocks and stores the corresponding expert_id for each block.
|
||||
*/
|
||||
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
|
||||
i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
if (threadIdx.x < num_experts) {
|
||||
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
|
||||
i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -106,6 +112,24 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int TOPK>
|
||||
__global__ void moe_sum_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., topk, d]
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
scalar_t x = 0.0;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < TOPK; ++k) {
|
||||
x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
|
||||
}
|
||||
out[token_idx * d + idx] = x;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace moe
|
||||
} // namespace vllm
|
||||
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
@ -117,18 +141,62 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
const int32_t shared_mem =
|
||||
((num_experts + 1) * num_experts + (num_experts + 1)) *
|
||||
((num_thread + 1) * num_experts + (num_experts + 1)) *
|
||||
sizeof(int32_t);
|
||||
|
||||
// set dynamic shared mem
|
||||
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
|
||||
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem));
|
||||
kernel<<<1, num_experts, shared_mem, stream>>>(
|
||||
kernel<<<1, num_thread, shared_mem, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel());
|
||||
});
|
||||
}
|
||||
|
||||
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
|
||||
torch::Tensor& output) // [num_tokens, hidden_size]
|
||||
{
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = output.numel() / hidden_size;
|
||||
const int topk = input.size(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
switch (topk) {
|
||||
case 2:
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
||||
vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
hidden_size);
|
||||
});
|
||||
break;
|
||||
|
||||
case 3:
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
||||
vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
|
||||
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
hidden_size);
|
||||
});
|
||||
break;
|
||||
|
||||
case 4:
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
||||
vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
hidden_size);
|
||||
});
|
||||
break;
|
||||
|
||||
default:
|
||||
at::sum_out(output, input, 1);
|
||||
break;
|
||||
}
|
||||
}
|
@ -5,3 +5,10 @@
|
||||
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
|
||||
torch::Tensor& token_expert_indices,
|
||||
torch::Tensor& gating_output);
|
||||
|
||||
void moe_sum(torch::Tensor& input, torch::Tensor& output);
|
||||
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int64_t block_size, torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
@ -8,6 +8,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"token_expert_indices, Tensor gating_output) -> ()");
|
||||
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||
|
||||
// Calculate the result of moe by summing up the partial results
|
||||
// from all selected experts.
|
||||
m.def("moe_sum(Tensor! input, Tensor output) -> ()");
|
||||
m.impl("moe_sum", torch::kCUDA, &moe_sum);
|
||||
|
||||
// Aligning the number of tokens to be processed by each expert such
|
||||
// that it is divisible by the block size.
|
||||
m.def(
|
||||
"moe_align_block_size(Tensor topk_ids, int num_experts,"
|
||||
" int block_size, Tensor! sorted_token_ids,"
|
||||
" Tensor! experts_ids,"
|
||||
" Tensor! num_tokens_post_pad) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
|
@ -145,11 +145,6 @@ void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
|
||||
c10::optional<torch::Tensor> const& scale_ub);
|
||||
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int64_t block_size, torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
||||
const torch::Tensor& A, const torch::Tensor& B,
|
||||
const torch::Tensor& C,
|
||||
|
@ -336,15 +336,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
|
||||
&dynamic_per_token_scaled_fp8_quant);
|
||||
|
||||
// Aligning the number of tokens to be processed by each expert such
|
||||
// that it is divisible by the block size.
|
||||
ops.def(
|
||||
"moe_align_block_size(Tensor topk_ids, int num_experts,"
|
||||
" int block_size, Tensor! sorted_token_ids,"
|
||||
" Tensor! experts_ids,"
|
||||
" Tensor! num_tokens_post_pad) -> ()");
|
||||
ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
||||
|
@ -19,7 +19,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
marlin_quantize)
|
||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import seed_everything
|
||||
from vllm.utils import is_hip, seed_everything
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
|
||||
@ -103,6 +103,7 @@ def test_mixtral_moe(dtype: torch.dtype):
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
||||
@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
|
||||
def test_fused_marlin_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
@ -255,6 +256,7 @@ def test_fused_marlin_moe(
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
||||
@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
|
||||
def test_single_marlin_moe_multiply(
|
||||
m: int,
|
||||
n: int,
|
||||
@ -345,6 +347,6 @@ def test_moe_align_block_size_opcheck():
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
|
||||
opcheck(torch.ops._C.moe_align_block_size,
|
||||
opcheck(torch.ops._moe_C.moe_align_block_size,
|
||||
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
|
||||
num_tokens_post_pad))
|
||||
|
@ -813,13 +813,17 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
|
||||
|
||||
|
||||
# moe
|
||||
def moe_sum(input: torch.Tensor, output: torch.Tensor):
|
||||
torch.ops._moe_C.moe_sum(input, output)
|
||||
|
||||
|
||||
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
|
||||
block_size: int, sorted_token_ids: torch.Tensor,
|
||||
experts_ids: torch.Tensor,
|
||||
num_tokens_post_pad: torch.Tensor) -> None:
|
||||
torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
|
||||
sorted_token_ids, experts_ids,
|
||||
num_tokens_post_pad)
|
||||
torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size,
|
||||
sorted_token_ids, experts_ids,
|
||||
num_tokens_post_pad)
|
||||
|
||||
|
||||
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
|
@ -589,9 +589,8 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16)
|
||||
|
||||
torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
|
||||
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx])
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user