Optimize moe_align_block_size for deepseek_v3 (#12850)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
bffddd9a05
commit
2344192a55
@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// taken from
|
// taken from
|
||||||
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
|
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void sgl_moe_align_block_size_kernel(
|
__global__ void sgl_moe_align_block_size_kernel(
|
||||||
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
||||||
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
|
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
|
||||||
int32_t block_size, size_t numel, int32_t* cumsum) {
|
int32_t block_size, size_t numel, int32_t* cumsum) {
|
||||||
__shared__ int32_t shared_counts[32][8];
|
__shared__ int32_t shared_counts[32][8];
|
||||||
__shared__ int32_t local_offsets[256];
|
|
||||||
|
|
||||||
const int warp_id = threadIdx.x / 32;
|
const int warp_id = threadIdx.x / 32;
|
||||||
const int lane_id = threadIdx.x % 32;
|
|
||||||
const int experts_per_warp = 8;
|
const int experts_per_warp = 8;
|
||||||
const int my_expert_start = warp_id * experts_per_warp;
|
const int my_expert_start = warp_id * experts_per_warp;
|
||||||
|
|
||||||
|
// Initialize shared_counts for this warp's experts
|
||||||
for (int i = 0; i < experts_per_warp; ++i) {
|
for (int i = 0; i < experts_per_warp; ++i) {
|
||||||
if (my_expert_start + i < num_experts) {
|
if (my_expert_start + i < num_experts) {
|
||||||
shared_counts[warp_id][i] = 0;
|
shared_counts[warp_id][i] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||||
|
|
||||||
@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel(
|
|||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
// Single thread computes cumulative sum and total tokens
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
cumsum[0] = 0;
|
cumsum[0] = 0;
|
||||||
for (int i = 1; i <= num_experts; ++i) {
|
for (int i = 1; i <= num_experts; ++i) {
|
||||||
@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel(
|
|||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
// Assign expert IDs to blocks
|
||||||
if (threadIdx.x < num_experts) {
|
if (threadIdx.x < num_experts) {
|
||||||
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
|
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
|
||||||
i += block_size) {
|
i += block_size) {
|
||||||
expert_ids[i / block_size] = threadIdx.x;
|
expert_ids[i / block_size] = threadIdx.x;
|
||||||
}
|
}
|
||||||
local_offsets[threadIdx.x] = cumsum[threadIdx.x];
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
__syncthreads();
|
// taken from
|
||||||
|
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
|
||||||
|
template <typename scalar_t>
|
||||||
|
__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids,
|
||||||
|
int32_t* sorted_token_ids,
|
||||||
|
int32_t* cumsum_buffer,
|
||||||
|
size_t numel) {
|
||||||
|
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const size_t stride = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
for (size_t i = tid; i < numel; i += stride) {
|
||||||
int32_t expert_id = topk_ids[i];
|
int32_t expert_id = topk_ids[i];
|
||||||
int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
|
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
|
||||||
sorted_token_ids[rank_post_pad] = i;
|
sorted_token_ids[rank_post_pad] = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -377,23 +388,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
|||||||
torch::Tensor experts_ids,
|
torch::Tensor experts_ids,
|
||||||
torch::Tensor num_tokens_post_pad) {
|
torch::Tensor num_tokens_post_pad) {
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
TORCH_CHECK(num_experts == 256,
|
||||||
|
"sgl_moe_align_block_size kernel only supports deepseek v3.");
|
||||||
|
|
||||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||||
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
|
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
|
||||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
// calc needed amount of shared mem for `cumsum` tensors
|
||||||
// tensors
|
|
||||||
auto options_int =
|
auto options_int =
|
||||||
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
|
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
|
||||||
// torch::Tensor token_cnts_buffer =
|
|
||||||
// torch::empty({(num_experts + 1) * num_experts}, options_int);
|
|
||||||
torch::Tensor cumsum_buffer =
|
torch::Tensor cumsum_buffer =
|
||||||
torch::empty({num_experts + 1}, options_int);
|
torch::zeros({num_experts + 1}, options_int);
|
||||||
|
|
||||||
auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
|
auto align_kernel =
|
||||||
kernel<<<1, 1024, 0, stream>>>(
|
vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
|
||||||
|
align_kernel<<<1, 1024, 0, stream>>>(
|
||||||
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
||||||
experts_ids.data_ptr<int32_t>(),
|
experts_ids.data_ptr<int32_t>(),
|
||||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||||
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
|
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
|
||||||
|
|
||||||
|
const int block_threads = 256;
|
||||||
|
const int num_blocks =
|
||||||
|
(topk_ids.numel() + block_threads - 1) / block_threads;
|
||||||
|
const int max_blocks = 65535;
|
||||||
|
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||||
|
auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>;
|
||||||
|
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
|
||||||
|
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
||||||
|
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -596,7 +596,7 @@ def moe_align_block_size(
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=topk_ids.device)
|
device=topk_ids.device)
|
||||||
if num_experts >= 224:
|
if num_experts >= 224:
|
||||||
if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
|
if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
|
||||||
moe_align_block_size_triton(
|
moe_align_block_size_triton(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
@ -606,6 +606,7 @@ def moe_align_block_size(
|
|||||||
num_tokens_post_pad,
|
num_tokens_post_pad,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Currently requires num_experts=256
|
||||||
ops.sgl_moe_align_block_size(
|
ops.sgl_moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user