[Performance][Kernel] Fused_moe Performance Improvement (#9384)
Signed-off-by: charlifu <charlifu@amd.com>
parent e26d37a185
commit 59449095ab
@@ -195,7 +195,6 @@ set(VLLM_EXT_SRC
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"
   "csrc/cuda_utils_kernels.cu"
-  "csrc/moe_align_block_size_kernels.cu"
   "csrc/prepare_inputs/advance_step.cu"
   "csrc/torch_bindings.cpp")

@@ -405,6 +404,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

 set(VLLM_MOE_EXT_SRC
   "csrc/moe/torch_bindings.cpp"
+  "csrc/moe/moe_align_sum_kernels.cu"
   "csrc/moe/topk_softmax_kernels.cu")

 set_gencode_flags_for_srcs(
@@ -1,15 +1,17 @@
 #include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>

 #include <ATen/ATen.h>
 #include <THC/THCAtomics.cuh>

-#include "cuda_compat.h"
-#include "dispatch_utils.h"
+#include "../cuda_compat.h"
+#include "../dispatch_utils.h"

 #define CEILDIV(x, y) (((x) + (y) - 1) / (y))

 namespace vllm {
+namespace moe {

 namespace {
 __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
@@ -32,10 +34,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   extern __shared__ int32_t shared_mem[];

   int32_t* tokens_cnts =
-      shared_mem;  // 2d tensor with shape (num_experts + 1, num_experts)
+      shared_mem;  // 2d tensor with shape (blockDim.x + 1, num_experts)
   int32_t* cumsum =
-      shared_mem + (num_experts + 1) *
-                       num_experts;  // 1d tensor with shape (num_experts + 1)
+      shared_mem +
+      (blockDim.x + 1) * num_experts;  // 1d tensor with shape (num_experts + 1)

   for (int i = 0; i < num_experts; ++i) {
     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
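The scratch space is now sized by the number of launched threads rather than by the number of experts: later in this diff the kernel is launched with num_thread = max(num_experts, WARP_SIZE) threads, so every thread gets its own row of per-expert counters, and the new threadIdx.x < num_experts guards in the next two hunks restrict the per-expert reduction steps to the first num_experts threads. A small worked example of the resulting shared-memory footprint (the numbers below are illustrative assumptions, not taken from the diff):

    num_experts = 8
    warp_size = 64  # e.g. a ROCm wavefront; 32 on NVIDIA hardware
    num_thread = max(num_experts, warp_size)
    # (num_thread + 1) rows of per-expert counters plus (num_experts + 1) cumsum slots, int32 each
    shared_mem_bytes = ((num_thread + 1) * num_experts + (num_experts + 1)) * 4
    print(num_thread, shared_mem_bytes)  # 64 2116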
@@ -53,11 +55,13 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   __syncthreads();

   // For each expert we accumulate the token counts from the different threads.
+  if (threadIdx.x < num_experts) {
     tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
     for (int i = 1; i <= blockDim.x; ++i) {
       tokens_cnts[index(num_experts, i, threadIdx.x)] +=
           tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
     }
+  }

   __syncthreads();

@@ -79,10 +83,12 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
+  if (threadIdx.x < num_experts) {
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
          i += block_size) {
       expert_ids[i / block_size] = threadIdx.x;
     }
+  }

  /**
   * Each thread processes a token shard, calculating the index of each token
@@ -106,6 +112,24 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
     ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
   }
 }
+
+template <typename scalar_t, int TOPK>
+__global__ void moe_sum_kernel(
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., topk, d]
+    const int d) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    scalar_t x = 0.0;
+#pragma unroll
+    for (int k = 0; k < TOPK; ++k) {
+      x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
+    }
+    out[token_idx * d + idx] = x;
+  }
+}
+
+}  // namespace moe
 }  // namespace vllm

 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
@@ -117,18 +141,62 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
         // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
         // tensors
+        const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
         const int32_t shared_mem =
-            ((num_experts + 1) * num_experts + (num_experts + 1)) *
+            ((num_thread + 1) * num_experts + (num_experts + 1)) *
             sizeof(int32_t);

         // set dynamic shared mem
-        auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
+        auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
         AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
             (void*)kernel, shared_mem));
-        kernel<<<1, num_experts, shared_mem, stream>>>(
+        kernel<<<1, num_thread, shared_mem, stream>>>(
             topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
             experts_ids.data_ptr<int32_t>(),
             num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
             topk_ids.numel());
       });
 }
+
+void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
+             torch::Tensor& output)  // [num_tokens, hidden_size]
+{
+  const int hidden_size = input.size(-1);
+  const int num_tokens = output.numel() / hidden_size;
+  const int topk = input.size(1);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  switch (topk) {
+    case 2:
+      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
+        vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
+            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+            hidden_size);
+      });
+      break;
+
+    case 3:
+      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
+        vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
+            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+            hidden_size);
+      });
+      break;
+
+    case 4:
+      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
+        vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
+            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+            hidden_size);
+      });
+      break;
+
+    default:
+      at::sum_out(output, input, 1);
+      break;
+  }
+}
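The host wrapper launches one block per token with at most 1024 threads per block, specializing the templated kernel for the common top-k values 2, 3 and 4 and falling back to an ATen reduction over dim 1 for anything else. A rough sketch of the launch geometry with illustrative sizes (these shapes are assumptions, not values from the diff):

    num_tokens, topk, hidden_size = 256, 2, 4096  # example shapes only
    grid = num_tokens                     # one block per token
    block = min(hidden_size, 1024)        # threads per block, capped at 1024
    elems_per_thread = (hidden_size + block - 1) // block
    print(grid, block, elems_per_thread)  # 256 1024 4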
@@ -5,3 +5,10 @@
 void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
                   torch::Tensor& token_expert_indices,
                   torch::Tensor& gating_output);
+
+void moe_sum(torch::Tensor& input, torch::Tensor& output);
+
+void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
+                          int64_t block_size, torch::Tensor sorted_token_ids,
+                          torch::Tensor experts_ids,
+                          torch::Tensor num_tokens_post_pad);
@@ -8,6 +8,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "token_expert_indices, Tensor gating_output) -> ()");
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

+  // Calculate the result of moe by summing up the partial results
+  // from all selected experts.
+  m.def("moe_sum(Tensor! input, Tensor output) -> ()");
+  m.impl("moe_sum", torch::kCUDA, &moe_sum);
+
+  // Aligning the number of tokens to be processed by each expert such
+  // that it is divisible by the block size.
+  m.def(
+      "moe_align_block_size(Tensor topk_ids, int num_experts,"
+      " int block_size, Tensor! sorted_token_ids,"
+      " Tensor! experts_ids,"
+      " Tensor! num_tokens_post_pad) -> ()");
+  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
+
 #ifndef USE_ROCM
   m.def(
       "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
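These registrations live in the separate MoE extension (the kernel source was moved into VLLM_MOE_EXT_SRC in the CMake hunks above), which is why the op namespace in the test and Python wrapper changes below switches from _C to _moe_C. A minimal smoke check, under the assumption that importing vllm._moe_C is what loads the extension and registers the ops in a built vLLM tree:

    import torch
    import vllm._moe_C  # noqa: F401  (assumed: loading the extension registers the ops)

    assert hasattr(torch.ops._moe_C, "moe_sum")
    assert hasattr(torch.ops._moe_C, "moe_align_block_size")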
|
@ -145,11 +145,6 @@ void dynamic_per_token_scaled_fp8_quant(
|
|||||||
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
|
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
|
||||||
c10::optional<torch::Tensor> const& scale_ub);
|
c10::optional<torch::Tensor> const& scale_ub);
|
||||||
|
|
||||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
|
||||||
int64_t block_size, torch::Tensor sorted_token_ids,
|
|
||||||
torch::Tensor experts_ids,
|
|
||||||
torch::Tensor num_tokens_post_pad);
|
|
||||||
|
|
||||||
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
||||||
const torch::Tensor& A, const torch::Tensor& B,
|
const torch::Tensor& A, const torch::Tensor& B,
|
||||||
const torch::Tensor& C,
|
const torch::Tensor& C,
|
||||||
|
@ -336,15 +336,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
|
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
|
||||||
&dynamic_per_token_scaled_fp8_quant);
|
&dynamic_per_token_scaled_fp8_quant);
|
||||||
|
|
||||||
// Aligning the number of tokens to be processed by each expert such
|
|
||||||
// that it is divisible by the block size.
|
|
||||||
ops.def(
|
|
||||||
"moe_align_block_size(Tensor topk_ids, int num_experts,"
|
|
||||||
" int block_size, Tensor! sorted_token_ids,"
|
|
||||||
" Tensor! experts_ids,"
|
|
||||||
" Tensor! num_tokens_post_pad) -> ()");
|
|
||||||
ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
|
||||||
|
|
||||||
// Compute int8 quantized tensor for given scaling factor.
|
// Compute int8 quantized tensor for given scaling factor.
|
||||||
ops.def(
|
ops.def(
|
||||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
||||||
|
@ -19,7 +19,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
|||||||
marlin_quantize)
|
marlin_quantize)
|
||||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
from vllm.utils import seed_everything
|
from vllm.utils import is_hip, seed_everything
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
|
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
|
||||||
@@ -103,6 +103,7 @@ def test_mixtral_moe(dtype: torch.dtype):
 @pytest.mark.parametrize("act_order", [True, False])
 @pytest.mark.parametrize("num_bits", [4, 8])
 @pytest.mark.parametrize("is_k_full", [True, False])
+@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
 def test_fused_marlin_moe(
     m: int,
     n: int,
@@ -255,6 +256,7 @@ def test_fused_marlin_moe(
 @pytest.mark.parametrize("act_order", [True, False])
 @pytest.mark.parametrize("num_bits", [4, 8])
 @pytest.mark.parametrize("is_k_full", [True, False])
+@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
 def test_single_marlin_moe_multiply(
     m: int,
     n: int,
@@ -345,6 +347,6 @@ def test_moe_align_block_size_opcheck():
                              dtype=torch.int32,
                              device=topk_ids.device)

-    opcheck(torch.ops._C.moe_align_block_size,
+    opcheck(torch.ops._moe_C.moe_align_block_size,
            (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
             num_tokens_post_pad))
@@ -813,11 +813,15 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,


 # moe
+def moe_sum(input: torch.Tensor, output: torch.Tensor):
+    torch.ops._moe_C.moe_sum(input, output)
+
+
 def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                          block_size: int, sorted_token_ids: torch.Tensor,
                          experts_ids: torch.Tensor,
                          num_tokens_post_pad: torch.Tensor) -> None:
-    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
-                                      sorted_token_ids, experts_ids,
-                                      num_tokens_post_pad)
+    torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size,
+                                          sorted_token_ids, experts_ids,
+                                          num_tokens_post_pad)

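The new wrapper simply forwards to the registered CUDA op. A small usage sketch checking it against the reference reduction it replaces below (the shapes are illustrative; this assumes a CUDA device and a vLLM build with the _moe_C extension available):

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(33, 2, 1024, device="cuda", dtype=torch.float16)  # [num_tokens, topk, hidden_size]
    ref = x.sum(dim=1)
    out = torch.empty_like(ref)
    ops.moe_sum(x, out)
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)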
@@ -589,9 +589,8 @@ def fused_experts(hidden_states: torch.Tensor,
                             use_fp8_w8a8=use_fp8_w8a8,
                             use_int8_w8a16=use_int8_w8a16)

-        torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                  dim=1,
-                  out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
+        ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx])
     return out_hidden_states