Optimize moe_align_block_size for deepseek_v3 (#12850)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-02-13 18:43:37 -05:00 committed by GitHub
parent bffddd9a05
commit 2344192a55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 15 deletions

View File

@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel(
} }
// taken from // taken from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a // https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
template <typename scalar_t> template <typename scalar_t>
__global__ void sgl_moe_align_block_size_kernel( __global__ void sgl_moe_align_block_size_kernel(
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel, int32_t* cumsum) { int32_t block_size, size_t numel, int32_t* cumsum) {
__shared__ int32_t shared_counts[32][8]; __shared__ int32_t shared_counts[32][8];
__shared__ int32_t local_offsets[256];
const int warp_id = threadIdx.x / 32; const int warp_id = threadIdx.x / 32;
const int lane_id = threadIdx.x % 32;
const int experts_per_warp = 8; const int experts_per_warp = 8;
const int my_expert_start = warp_id * experts_per_warp; const int my_expert_start = warp_id * experts_per_warp;
// Initialize shared_counts for this warp's experts
for (int i = 0; i < experts_per_warp; ++i) { for (int i = 0; i < experts_per_warp; ++i) {
if (my_expert_start + i < num_experts) { if (my_expert_start + i < num_experts) {
shared_counts[warp_id][i] = 0; shared_counts[warp_id][i] = 0;
} }
} }
__syncthreads();
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread; const size_t start_idx = threadIdx.x * tokens_per_thread;
@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel(
__syncthreads(); __syncthreads();
// Single thread computes cumulative sum and total tokens
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
cumsum[0] = 0; cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) { for (int i = 1; i <= num_experts; ++i) {
@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel(
__syncthreads(); __syncthreads();
// Assign expert IDs to blocks
if (threadIdx.x < num_experts) { if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) { i += block_size) {
expert_ids[i / block_size] = threadIdx.x; expert_ids[i / block_size] = threadIdx.x;
} }
local_offsets[threadIdx.x] = cumsum[threadIdx.x];
} }
}
__syncthreads(); // taken from
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
template <typename scalar_t>
__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids,
int32_t* sorted_token_ids,
int32_t* cumsum_buffer,
size_t numel) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i]; int32_t expert_id = topk_ids[i];
int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1); int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
sorted_token_ids[rank_post_pad] = i; sorted_token_ids[rank_post_pad] = i;
} }
} }
@ -377,23 +388,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) { torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3.");
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` // calc needed amount of shared mem for `cumsum` tensors
// tensors
auto options_int = auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
// torch::Tensor token_cnts_buffer =
// torch::empty({(num_experts + 1) * num_experts}, options_int);
torch::Tensor cumsum_buffer = torch::Tensor cumsum_buffer =
torch::empty({num_experts + 1}, options_int); torch::zeros({num_experts + 1}, options_int);
auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>; auto align_kernel =
kernel<<<1, 1024, 0, stream>>>( vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
align_kernel<<<1, 1024, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>()); topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
const int block_threads = 256;
const int num_blocks =
(topk_ids.numel() + block_threads - 1) / block_threads;
const int max_blocks = 65535;
const int actual_blocks = std::min(num_blocks, max_blocks);
auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>;
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
}); });
} }

View File

@ -596,7 +596,7 @@ def moe_align_block_size(
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
if num_experts >= 224: if num_experts >= 224:
if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
moe_align_block_size_triton( moe_align_block_size_triton(
topk_ids, topk_ids,
num_experts, num_experts,
@ -606,6 +606,7 @@ def moe_align_block_size(
num_tokens_post_pad, num_tokens_post_pad,
) )
else: else:
# Currently requires num_experts=256
ops.sgl_moe_align_block_size( ops.sgl_moe_align_block_size(
topk_ids, topk_ids,
num_experts, num_experts,