[Kernel] port sgl moe_align_block_size kernels (#12574)

sgl_moe_align_block_size is based on: ded9fcd09a
moe_align_block_size is based on: ba5112ff69

Signed-off-by: Yang Chen <yangche@fb.com>

parent 326fcc8b9f
commit 95460fc513
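For orientation, here is a minimal pure-Python sketch (not part of this commit; the function and variable names are illustrative) of the layout that both the ported CUDA kernel and the Triton fallback compute: count tokens per expert, pad each expert's count up to a multiple of block_size, prefix-sum the padded counts, then scatter token indices into their expert's padded range.

def ref_moe_align_block_size(topk_ids, num_experts, block_size):
    # Count how many routed tokens each expert received.
    counts = [0] * num_experts
    for e in topk_ids:
        counts[e] += 1
    # Pad every expert's count to a multiple of block_size and prefix-sum.
    cumsum = [0] * (num_experts + 1)
    for e in range(num_experts):
        padded = (counts[e] + block_size - 1) // block_size * block_size
        cumsum[e + 1] = cumsum[e] + padded
    total_padded = cumsum[num_experts]
    # One expert id per block of sorted-token slots.
    expert_ids = [0] * (total_padded // block_size)
    for e in range(num_experts):
        for blk in range(cumsum[e] // block_size, cumsum[e + 1] // block_size):
            expert_ids[blk] = e
    # Scatter token indices into their expert's padded range; padding slots
    # keep a sentinel value (here len(topk_ids)).
    sorted_token_ids = [len(topk_ids)] * total_padded
    write = cumsum[:num_experts]
    for i, e in enumerate(topk_ids):
        sorted_token_ids[write[e]] = i
        write[e] += 1
    return sorted_token_ids, expert_ids, total_padded

For example, ref_moe_align_block_size([1, 0, 1], num_experts=2, block_size=4) yields sorted token ids [1, 3, 3, 3, 0, 2, 3, 3], expert ids [0, 1], and 8 post-padding tokens.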
@@ -197,6 +197,72 @@ __global__ void moe_align_block_size_global_mem_kernel(
  }
}

// taken from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
template <typename scalar_t>
__global__ void sgl_moe_align_block_size_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t* cumsum) {
  __shared__ int32_t shared_counts[32][8];
  __shared__ int32_t local_offsets[256];

  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;
  const int experts_per_warp = 8;
  const int my_expert_start = warp_id * experts_per_warp;

  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < num_experts) {
      shared_counts[warp_id][i] = 0;
    }
  }

  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int expert_id = topk_ids[i];
    int warp_idx = expert_id / experts_per_warp;
    int expert_offset = expert_id % experts_per_warp;
    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
  }

  __syncthreads();

  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      int expert_count = 0;
      int warp_idx = (i - 1) / experts_per_warp;
      int expert_offset = (i - 1) % experts_per_warp;
      expert_count = shared_counts[warp_idx][expert_offset];

      cumsum[i] =
          cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
  }

  __syncthreads();

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
    sorted_token_ids[rank_post_pad] = i;
  }
}

template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
    scalar_t* __restrict__ out,  // [..., d]
@@ -305,6 +371,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
  }
}

void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
        // tensors
        auto options_int =
            torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
        // torch::Tensor token_cnts_buffer =
        //     torch::empty({(num_experts + 1) * num_experts}, options_int);
        torch::Tensor cumsum_buffer =
            torch::empty({num_experts + 1}, options_int);

        auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
        kernel<<<1, 1024, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
      });
}

void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
             torch::Tensor& output)  // [num_tokens, hidden_size]
{
@@ -12,3 +12,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad);
@@ -22,6 +22,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
      " Tensor! num_tokens_post_pad) -> ()");
  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

  // temporarily adapted from
  // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
  m.def(
      "sgl_moe_align_block_size(Tensor topk_ids, int num_experts,"
      " int block_size, Tensor! sorted_token_ids,"
      " Tensor! experts_ids,"
      " Tensor! num_tokens_post_pad) -> ()");
  m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);

#ifndef USE_ROCM
  m.def(
      "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
@@ -952,6 +952,15 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                                             num_tokens_post_pad)


def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                             block_size: int, sorted_token_ids: torch.Tensor,
                             experts_ids: torch.Tensor,
                             num_tokens_post_pad: torch.Tensor) -> None:
    torch.ops._moe_C.sgl_moe_align_block_size(topk_ids, num_experts,
                                              block_size, sorted_token_ids,
                                              experts_ids, num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: float) -> None:
@@ -82,6 +82,7 @@ if TYPE_CHECKING:
    VLLM_MLA_DISABLE: bool = False
    VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
    VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
    VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False


def get_default_cache_root():
@@ -531,7 +532,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    # matrices to match the activation type. This can lead to higher memory and
    # compute usage but better preserves the accuracy of the original model.
    "VLLM_MLA_DISABLE_REQUANTIZATION":
    lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0"))),

    # If set, vLLM will use the Triton implementation of moe_align_block_size,
    # i.e. moe_align_block_size_triton in fused_moe.py.
    "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
    lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
                 ),
}

# end-env-vars-definition
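A brief usage note (illustrative, not part of the diff): the new flag is read through os.getenv, so opting into the Triton path is just a matter of setting it in the environment before the MoE path runs, for example:

import os

# Opt in to the Triton moe_align_block_size implementation; the default "0"
# keeps the ported CUDA kernel as the large-expert-count path.
os.environ["VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"] = "1"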
@@ -405,6 +405,144 @@ def fused_moe_kernel(
    tl.store(c_ptrs, accumulator, mask=c_mask)


def ceil_div(a, b):
    return (a + b - 1) // b


@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)

    start_idx = pid * tokens_per_thread

    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)


@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)


@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    last_cumsum = 0
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)


@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
                                         numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)


# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    numel = topk_ids.numel()
    grid = (num_experts, )
    tokens_cnts = torch.zeros((num_experts + 1, num_experts),
                              dtype=torch.int32,
                              device=topk_ids.device)
    cumsum = torch.zeros((num_experts + 1, ),
                         dtype=torch.int32,
                         device=topk_ids.device)
    tokens_per_thread = ceil_div(numel, num_experts)

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
    )
    moe_align_block_size_stage2[grid](
        tokens_cnts,
        num_experts,
    )
    moe_align_block_size_stage3[(1, )](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )


def moe_align_block_size(
        topk_ids: torch.Tensor, block_size: int,
        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
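For orientation (illustrative only, not part of the commit), a hypothetical standalone driver for moe_align_block_size_triton with pre-allocated output buffers; the tensor shapes and fill values here are assumptions sized to cover the worst-case padding.

import torch

# Hypothetical driver: route 1024 tokens with top-8 experts, then align.
num_experts, block_size = 256, 128
topk_ids = torch.randint(0, num_experts, (1024, 8),
                         dtype=torch.int32, device="cuda")
max_padded = topk_ids.numel() + num_experts * (block_size - 1)
# Padding slots keep the sentinel value topk_ids.numel().
sorted_token_ids = torch.full((max_padded, ), topk_ids.numel(),
                              dtype=torch.int32, device="cuda")
expert_ids = torch.empty((ceil_div(max_padded, block_size), ),
                         dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty((1, ), dtype=torch.int32, device="cuda")

moe_align_block_size_triton(topk_ids, num_experts, block_size,
                            sorted_token_ids, expert_ids, num_tokens_post_pad)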
@@ -457,6 +595,26 @@ def moe_align_block_size(
    num_tokens_post_pad = torch.empty((1),
                                      dtype=torch.int32,
                                      device=topk_ids.device)
    if num_experts >= 224:
        if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
            moe_align_block_size_triton(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            )
        else:
            ops.sgl_moe_align_block_size(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            )
    else:
        ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                                 expert_ids, num_tokens_post_pad)
    return sorted_ids, expert_ids, num_tokens_post_pad